mirror of
https://github.com/joaovitoriasilva/endurain.git
synced 2026-01-07 23:13:57 -05:00
Add tests for identity_providers module and update .gitignore
Added comprehensive unit tests for the identity_providers module, including CRUD operations, schema validation, and utility functions. Updated .gitignore to exclude deeper __pycache__ directories. Removed session test files and old __pycache__ files from the repository.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -15,6 +15,9 @@ backend/app/*/*/*/__pycache__/
|
||||
backend/app/*.pyc
|
||||
backend/tests/__pycache__/
|
||||
backend/tests/*/__pycache__/
|
||||
backend/tests/*/*/__pycache__/
|
||||
backend/tests/*/*/*/__pycache__/
|
||||
backend/tests/*/*/*/*/__pycache__/
|
||||
|
||||
# Tests
|
||||
backend/.coverage
|
||||
|
||||
1
backend/tests/auth/identity_providers/__init__.py
Normal file
1
backend/tests/auth/identity_providers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for identity_providers module."""
|
||||
677
backend/tests/auth/identity_providers/test_crud.py
Normal file
677
backend/tests/auth/identity_providers/test_crud.py
Normal file
@@ -0,0 +1,677 @@
|
||||
"""Tests for identity_providers.crud module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi import HTTPException
|
||||
|
||||
from auth.identity_providers import crud as idp_crud
|
||||
from auth.identity_providers.schema import (
|
||||
IdentityProviderCreate,
|
||||
IdentityProviderUpdate,
|
||||
)
|
||||
from auth.identity_providers.models import IdentityProvider
|
||||
|
||||
|
||||
class TestGetIdentityProvider:
|
||||
"""Test suite for get_identity_provider function."""
|
||||
|
||||
def test_get_identity_provider_success(self, mock_db):
|
||||
"""Test successfully retrieving an identity provider by ID.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Correct database query is executed
|
||||
- Identity provider is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.id = 1
|
||||
mock_idp.name = "Test Provider"
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_idp
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_identity_provider(1, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == mock_idp
|
||||
mock_db.query.assert_called_once()
|
||||
mock_db.query.return_value.filter.assert_called_once()
|
||||
|
||||
def test_get_identity_provider_not_found(self, mock_db):
|
||||
"""Test retrieving non-existent identity provider.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- None is returned when provider not found
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_identity_provider(999, mock_db)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_identity_provider_database_error(self, mock_db):
|
||||
"""Test get_identity_provider handles database errors.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised on database error
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.get_identity_provider(1, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "Internal Server Error" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
class TestGetIdentityProviderBySlug:
|
||||
"""Test suite for get_identity_provider_by_slug function."""
|
||||
|
||||
def test_get_identity_provider_by_slug_success(self, mock_db):
|
||||
"""Test successfully retrieving an identity provider by slug.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Correct database query is executed
|
||||
- Identity provider is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.slug = "test-provider"
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_idp
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_identity_provider_by_slug("test-provider", mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == mock_idp
|
||||
mock_db.query.assert_called_once()
|
||||
|
||||
def test_get_identity_provider_by_slug_not_found(self, mock_db):
|
||||
"""Test retrieving non-existent identity provider by slug.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- None is returned when provider not found
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_identity_provider_by_slug("nonexistent", mock_db)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_identity_provider_by_slug_database_error(self, mock_db):
|
||||
"""Test get_identity_provider_by_slug handles database errors.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised on database error
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.get_identity_provider_by_slug("test", mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
|
||||
class TestGetAllIdentityProviders:
|
||||
"""Test suite for get_all_identity_providers function."""
|
||||
|
||||
def test_get_all_identity_providers_success(self, mock_db):
|
||||
"""Test successfully retrieving all identity providers.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- All providers are returned ordered by name
|
||||
"""
|
||||
# Arrange
|
||||
mock_idps = [MagicMock(spec=IdentityProvider) for _ in range(3)]
|
||||
mock_db.query.return_value.order_by.return_value.all.return_value = mock_idps
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_all_identity_providers(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == mock_idps
|
||||
assert len(result) == 3
|
||||
mock_db.query.assert_called_once()
|
||||
mock_db.query.return_value.order_by.assert_called_once()
|
||||
|
||||
def test_get_all_identity_providers_empty(self, mock_db):
|
||||
"""Test retrieving all identity providers when none exist.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Empty list is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.return_value.order_by.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_all_identity_providers(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_get_all_identity_providers_database_error(self, mock_db):
|
||||
"""Test get_all_identity_providers handles database errors.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised on database error
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.get_all_identity_providers(mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
|
||||
class TestGetEnabledProviders:
|
||||
"""Test suite for get_enabled_providers function."""
|
||||
|
||||
def test_get_enabled_providers_success(self, mock_db):
|
||||
"""Test successfully retrieving enabled identity providers.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Only enabled providers are returned
|
||||
- Results are ordered by name
|
||||
"""
|
||||
# Arrange
|
||||
mock_idps = [MagicMock(spec=IdentityProvider) for _ in range(2)]
|
||||
(
|
||||
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value
|
||||
) = mock_idps
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_enabled_providers(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == mock_idps
|
||||
assert len(result) == 2
|
||||
mock_db.query.assert_called_once()
|
||||
mock_db.query.return_value.filter.assert_called_once()
|
||||
|
||||
def test_get_enabled_providers_empty(self, mock_db):
|
||||
"""Test retrieving enabled providers when none exist.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Empty list is returned
|
||||
"""
|
||||
# Arrange
|
||||
(
|
||||
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value
|
||||
) = []
|
||||
|
||||
# Act
|
||||
result = idp_crud.get_enabled_providers(mock_db)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_get_enabled_providers_database_error(self, mock_db):
|
||||
"""Test get_enabled_providers handles database errors.
|
||||
|
||||
Args:
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised on database error
|
||||
"""
|
||||
# Arrange
|
||||
mock_db.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.get_enabled_providers(mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
|
||||
class TestCreateIdentityProvider:
|
||||
"""Test suite for create_identity_provider function."""
|
||||
|
||||
@patch("auth.identity_providers.crud.idp_models.IdentityProvider")
|
||||
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
|
||||
def test_create_identity_provider_success(
|
||||
self, mock_get_by_slug, mock_encrypt, mock_idp_model, mock_db
|
||||
):
|
||||
"""Test successfully creating a new identity provider.
|
||||
|
||||
Args:
|
||||
mock_get_by_slug: Mocked get_identity_provider_by_slug function
|
||||
mock_encrypt: Mocked encryption function
|
||||
mock_idp_model: Mocked IdentityProvider model
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Identity provider is created with encrypted credentials
|
||||
- Database commit and refresh are called
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_by_slug.return_value = None # No existing provider
|
||||
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
|
||||
|
||||
mock_idp_instance = MagicMock()
|
||||
mock_idp_instance.id = 1
|
||||
mock_idp_instance.name = "Test Provider"
|
||||
mock_idp_model.return_value = mock_idp_instance
|
||||
|
||||
idp_data = IdentityProviderCreate(
|
||||
name="Test Provider",
|
||||
slug="test-provider",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-secret",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = idp_crud.create_identity_provider(idp_data, mock_db)
|
||||
|
||||
# Assert
|
||||
mock_get_by_slug.assert_called_once_with("test-provider", mock_db)
|
||||
assert mock_encrypt.call_count == 2 # Called for client_id and client_secret
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
mock_db.refresh.assert_called_once()
|
||||
assert result == mock_idp_instance
|
||||
|
||||
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
|
||||
def test_create_identity_provider_slug_exists(self, mock_get_by_slug, mock_db):
|
||||
"""Test creating identity provider with existing slug.
|
||||
|
||||
Args:
|
||||
mock_get_by_slug: Mocked get_identity_provider_by_slug function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 409 status is raised
|
||||
- Error message indicates slug conflict
|
||||
"""
|
||||
# Arrange
|
||||
mock_existing = MagicMock(spec=IdentityProvider)
|
||||
mock_get_by_slug.return_value = mock_existing
|
||||
|
||||
idp_data = IdentityProviderCreate(
|
||||
name="Test Provider",
|
||||
slug="existing-slug",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-secret",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.create_identity_provider(idp_data, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "already exists" in str(exc_info.value.detail)
|
||||
|
||||
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
|
||||
def test_create_identity_provider_database_error(
|
||||
self, mock_get_by_slug, mock_encrypt, mock_db
|
||||
):
|
||||
"""Test create_identity_provider handles database errors.
|
||||
|
||||
Args:
|
||||
mock_get_by_slug: Mocked get_identity_provider_by_slug function
|
||||
mock_encrypt: Mocked encryption function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised
|
||||
- Database rollback is called
|
||||
"""
|
||||
# Arrange
|
||||
mock_get_by_slug.return_value = None
|
||||
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
|
||||
mock_db.commit.side_effect = Exception("Database error")
|
||||
|
||||
idp_data = IdentityProviderCreate(
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id="client-id",
|
||||
client_secret="secret",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.create_identity_provider(idp_data, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
mock_db.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestUpdateIdentityProvider:
|
||||
"""Test suite for update_identity_provider function."""
|
||||
|
||||
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_update_identity_provider_success(
|
||||
self, mock_get, mock_get_by_slug, mock_encrypt, mock_db
|
||||
):
|
||||
"""Test successfully updating an identity provider.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_get_by_slug: Mocked get_identity_provider_by_slug function
|
||||
mock_encrypt: Mocked encryption function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Identity provider is updated
|
||||
- Encrypted fields are encrypted
|
||||
- Database commit and refresh are called
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.id = 1
|
||||
mock_idp.slug = "original-slug"
|
||||
mock_get.return_value = mock_idp
|
||||
mock_get_by_slug.return_value = None
|
||||
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
|
||||
|
||||
idp_data = IdentityProviderUpdate(
|
||||
name="Updated Provider",
|
||||
slug="updated-slug",
|
||||
client_id="new-client-id",
|
||||
client_secret="new-secret",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = idp_crud.update_identity_provider(1, idp_data, mock_db)
|
||||
|
||||
# Assert
|
||||
mock_get.assert_called_once_with(1, mock_db)
|
||||
mock_get_by_slug.assert_called_once_with("updated-slug", mock_db)
|
||||
assert mock_encrypt.call_count == 2
|
||||
mock_db.commit.assert_called_once()
|
||||
mock_db.refresh.assert_called_once()
|
||||
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_update_identity_provider_not_found(self, mock_get, mock_db):
|
||||
"""Test updating non-existent identity provider.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 404 status is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_get.return_value = None
|
||||
idp_data = IdentityProviderUpdate(
|
||||
name="Updated", slug="updated-slug", client_id="client-123"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.update_identity_provider(999, idp_data, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "not found" in str(exc_info.value.detail)
|
||||
|
||||
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_update_identity_provider_slug_conflict(
|
||||
self, mock_get, mock_get_by_slug, mock_db
|
||||
):
|
||||
"""Test updating identity provider with conflicting slug.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_get_by_slug: Mocked get_identity_provider_by_slug function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 409 status is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.slug = "original-slug"
|
||||
mock_get.return_value = mock_idp
|
||||
|
||||
mock_existing = MagicMock(spec=IdentityProvider)
|
||||
mock_existing.slug = "existing-slug"
|
||||
mock_get_by_slug.return_value = mock_existing
|
||||
|
||||
idp_data = IdentityProviderUpdate(
|
||||
name="Test", slug="existing-slug", client_id="client-123"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.update_identity_provider(1, idp_data, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "already exists" in str(exc_info.value.detail)
|
||||
|
||||
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_update_identity_provider_without_slug_change(
|
||||
self, mock_get, mock_encrypt, mock_db
|
||||
):
|
||||
"""Test updating identity provider without changing slug.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_encrypt: Mocked encryption function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- get_identity_provider_by_slug is not called
|
||||
- Update succeeds
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.slug = "test-slug"
|
||||
mock_get.return_value = mock_idp
|
||||
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
|
||||
|
||||
idp_data = IdentityProviderUpdate(
|
||||
name="Updated Name", slug="test-slug", client_id="client-123"
|
||||
)
|
||||
|
||||
# Act
|
||||
result = idp_crud.update_identity_provider(1, idp_data, mock_db)
|
||||
|
||||
# Assert
|
||||
mock_get.assert_called_once_with(1, mock_db)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_update_identity_provider_database_error(
|
||||
self, mock_get, mock_encrypt, mock_db
|
||||
):
|
||||
"""Test update_identity_provider handles database errors.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_encrypt: Mocked encryption function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised
|
||||
- Database rollback is called
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.slug = "test-slug"
|
||||
mock_get.return_value = mock_idp
|
||||
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
|
||||
mock_db.commit.side_effect = Exception("Database error")
|
||||
|
||||
# Use same slug to avoid conflict check
|
||||
idp_data = IdentityProviderUpdate(
|
||||
name="Updated", slug="test-slug", client_id="client-123"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.update_identity_provider(1, idp_data, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
mock_db.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestDeleteIdentityProvider:
|
||||
"""Test suite for delete_identity_provider function."""
|
||||
|
||||
@patch(
|
||||
"auth.identity_providers.crud.user_identity_providers_crud.check_user_identity_providers_by_idp_id"
|
||||
)
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_delete_identity_provider_success(
|
||||
self, mock_get, mock_check_users, mock_db
|
||||
):
|
||||
"""Test successfully deleting an identity provider.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_check_users: Mocked check for user links
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- Identity provider is deleted
|
||||
- Database commit is called
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_idp.id = 1
|
||||
mock_idp.name = "Test Provider"
|
||||
mock_get.return_value = mock_idp
|
||||
mock_check_users.return_value = None # No linked users
|
||||
|
||||
# Act
|
||||
idp_crud.delete_identity_provider(1, mock_db)
|
||||
|
||||
# Assert
|
||||
mock_get.assert_called_once_with(1, mock_db)
|
||||
mock_check_users.assert_called_once_with(1, mock_db)
|
||||
mock_db.delete.assert_called_once_with(mock_idp)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_delete_identity_provider_not_found(self, mock_get, mock_db):
|
||||
"""Test deleting non-existent identity provider.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 404 status is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.delete_identity_provider(999, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "not found" in str(exc_info.value.detail)
|
||||
|
||||
@patch(
|
||||
"auth.identity_providers.crud.user_identity_providers_crud.check_user_identity_providers_by_idp_id"
|
||||
)
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_delete_identity_provider_with_linked_users(
|
||||
self, mock_get, mock_check_users, mock_db
|
||||
):
|
||||
"""Test deleting identity provider with linked users.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_check_users: Mocked check for user links
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 409 status is raised
|
||||
- Error message indicates linked users
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_get.return_value = mock_idp
|
||||
mock_check_users.return_value = True # Has linked users
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.delete_identity_provider(1, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "linked users" in str(exc_info.value.detail)
|
||||
|
||||
@patch(
|
||||
"auth.identity_providers.crud.user_identity_providers_crud.check_user_identity_providers_by_idp_id"
|
||||
)
|
||||
@patch("auth.identity_providers.crud.get_identity_provider")
|
||||
def test_delete_identity_provider_database_error(
|
||||
self, mock_get, mock_check_users, mock_db
|
||||
):
|
||||
"""Test delete_identity_provider handles database errors.
|
||||
|
||||
Args:
|
||||
mock_get: Mocked get_identity_provider function
|
||||
mock_check_users: Mocked check for user links
|
||||
mock_db: Mocked database session
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 500 status is raised
|
||||
- Database rollback is called
|
||||
"""
|
||||
# Arrange
|
||||
mock_idp = MagicMock(spec=IdentityProvider)
|
||||
mock_get.return_value = mock_idp
|
||||
mock_check_users.return_value = None
|
||||
mock_db.delete.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
idp_crud.delete_identity_provider(1, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
mock_db.rollback.assert_called_once()
|
||||
812
backend/tests/auth/identity_providers/test_schema.py
Normal file
812
backend/tests/auth/identity_providers/test_schema.py
Normal file
@@ -0,0 +1,812 @@
|
||||
"""Tests for identity_providers.schema module."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from datetime import datetime
|
||||
|
||||
from auth.identity_providers.schema import (
|
||||
IdentityProviderBase,
|
||||
IdentityProviderCreate,
|
||||
IdentityProviderUpdate,
|
||||
IdentityProvider,
|
||||
IdentityProviderPublic,
|
||||
IdentityProviderTemplate,
|
||||
TokenExchangeRequest,
|
||||
TokenExchangeResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestIdentityProviderBase:
|
||||
"""Test suite for IdentityProviderBase schema."""
|
||||
|
||||
def test_valid_identity_provider_base(self):
|
||||
"""Test creating a valid IdentityProviderBase instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully
|
||||
- All fields have correct values
|
||||
- Default values are applied
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(
|
||||
name="Test Provider",
|
||||
slug="test-provider",
|
||||
provider_type="oidc",
|
||||
enabled=True,
|
||||
issuer_url="https://auth.example.com",
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
userinfo_endpoint="https://auth.example.com/userinfo",
|
||||
jwks_uri="https://auth.example.com/jwks",
|
||||
scopes="openid profile email",
|
||||
icon="test-icon",
|
||||
auto_create_users=True,
|
||||
sync_user_info=True,
|
||||
user_mapping={"username": ["preferred_username"], "email": ["email"]},
|
||||
client_id="test-client-id",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.name == "Test Provider"
|
||||
assert idp.slug == "test-provider"
|
||||
assert idp.provider_type == "oidc"
|
||||
assert idp.enabled is True
|
||||
assert idp.issuer_url == "https://auth.example.com"
|
||||
assert idp.scopes == "openid profile email"
|
||||
assert idp.auto_create_users is True
|
||||
assert idp.sync_user_info is True
|
||||
assert idp.client_id == "test-client-id"
|
||||
|
||||
def test_identity_provider_base_with_defaults(self):
|
||||
"""Test IdentityProviderBase with default values.
|
||||
|
||||
Asserts:
|
||||
- Default provider_type is 'oidc'
|
||||
- Default enabled is False
|
||||
- Default scopes is 'openid profile email'
|
||||
- Default auto_create_users is True
|
||||
- Default sync_user_info is True
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(name="Test", slug="test", client_id="client-123")
|
||||
|
||||
# Assert
|
||||
assert idp.provider_type == "oidc"
|
||||
assert idp.enabled is False
|
||||
assert idp.scopes == "openid profile email"
|
||||
assert idp.auto_create_users is True
|
||||
assert idp.sync_user_info is True
|
||||
|
||||
def test_slug_validation_lowercase_alphanumeric_hyphens(self):
|
||||
"""Test slug validation accepts valid characters.
|
||||
|
||||
Asserts:
|
||||
- Lowercase letters, numbers, and hyphens are accepted
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(
|
||||
name="Test", slug="test-provider-123", client_id="client-123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.slug == "test-provider-123"
|
||||
|
||||
def test_slug_validation_uppercase_rejected(self):
|
||||
"""Test slug validation rejects uppercase letters.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised for uppercase characters
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderBase(
|
||||
name="Test", slug="Test-Provider", client_id="client-123"
|
||||
)
|
||||
|
||||
assert "Slug must contain only lowercase letters" in str(exc_info.value)
|
||||
|
||||
def test_slug_validation_special_characters_rejected(self):
|
||||
"""Test slug validation rejects special characters.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised for special characters
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderBase(
|
||||
name="Test", slug="test_provider", client_id="client-123"
|
||||
)
|
||||
|
||||
assert "Slug must contain only lowercase letters" in str(exc_info.value)
|
||||
|
||||
def test_provider_type_validation_oidc(self):
|
||||
"""Test provider_type validation accepts 'oidc'.
|
||||
|
||||
Asserts:
|
||||
- 'oidc' is accepted as valid provider type
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(
|
||||
name="Test", slug="test", provider_type="oidc", client_id="client-123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.provider_type == "oidc"
|
||||
|
||||
def test_provider_type_validation_oauth2(self):
|
||||
"""Test provider_type validation accepts 'oauth2'.
|
||||
|
||||
Asserts:
|
||||
- 'oauth2' is accepted as valid provider type
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(
|
||||
name="Test", slug="test", provider_type="oauth2", client_id="client-123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.provider_type == "oauth2"
|
||||
|
||||
def test_provider_type_validation_saml(self):
|
||||
"""Test provider_type validation accepts 'saml'.
|
||||
|
||||
Asserts:
|
||||
- 'saml' is accepted as valid provider type
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(
|
||||
name="Test", slug="test", provider_type="saml", client_id="client-123"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.provider_type == "saml"
|
||||
|
||||
def test_provider_type_validation_invalid(self):
|
||||
"""Test provider_type validation rejects invalid types.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised for invalid provider type
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderBase(
|
||||
name="Test",
|
||||
slug="test",
|
||||
provider_type="invalid",
|
||||
client_id="client-123",
|
||||
)
|
||||
|
||||
assert "Provider type must be one of" in str(exc_info.value)
|
||||
|
||||
def test_name_max_length_validation(self):
|
||||
"""Test name field max length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when name exceeds 100 characters
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderBase(name="x" * 101, slug="test", client_id="client-123")
|
||||
|
||||
assert "String should have at most 100 characters" in str(exc_info.value)
|
||||
|
||||
def test_name_min_length_validation(self):
|
||||
"""Test name field min length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when name is empty
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderBase(name="", slug="test", client_id="client-123")
|
||||
|
||||
assert "String should have at least 1 character" in str(exc_info.value)
|
||||
|
||||
def test_slug_max_length_validation(self):
|
||||
"""Test slug field max length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when slug exceeds 50 characters
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderBase(name="Test", slug="x" * 51, client_id="client-123")
|
||||
|
||||
assert "String should have at most 50 characters" in str(exc_info.value)
|
||||
|
||||
def test_optional_fields_can_be_none(self):
|
||||
"""Test optional fields can be None.
|
||||
|
||||
Asserts:
|
||||
- issuer_url, authorization_endpoint, token_endpoint can be None
|
||||
- userinfo_endpoint, jwks_uri, icon can be None
|
||||
- user_mapping can be None
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderBase(
|
||||
name="Test",
|
||||
slug="test",
|
||||
issuer_url=None,
|
||||
authorization_endpoint=None,
|
||||
token_endpoint=None,
|
||||
userinfo_endpoint=None,
|
||||
jwks_uri=None,
|
||||
icon=None,
|
||||
user_mapping=None,
|
||||
client_id="client-123",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.issuer_url is None
|
||||
assert idp.authorization_endpoint is None
|
||||
assert idp.token_endpoint is None
|
||||
assert idp.userinfo_endpoint is None
|
||||
assert idp.jwks_uri is None
|
||||
assert idp.icon is None
|
||||
assert idp.user_mapping is None
|
||||
|
||||
|
||||
class TestIdentityProviderCreate:
|
||||
"""Test suite for IdentityProviderCreate schema."""
|
||||
|
||||
def test_valid_identity_provider_create(self):
|
||||
"""Test creating a valid IdentityProviderCreate instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully
|
||||
- client_secret is required and set
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderCreate(
|
||||
name="Test Provider",
|
||||
slug="test-provider",
|
||||
client_id="test-client-id",
|
||||
client_secret="test-client-secret",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.name == "Test Provider"
|
||||
assert idp.slug == "test-provider"
|
||||
assert idp.client_id == "test-client-id"
|
||||
assert idp.client_secret == "test-client-secret"
|
||||
|
||||
def test_client_secret_required(self):
|
||||
"""Test client_secret is required for IdentityProviderCreate.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when client_secret is missing
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderCreate(name="Test", slug="test", client_id="client-123")
|
||||
|
||||
assert "client_secret" in str(exc_info.value)
|
||||
assert "Field required" in str(exc_info.value)
|
||||
|
||||
def test_client_secret_min_length(self):
|
||||
"""Test client_secret minimum length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when client_secret is empty
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderCreate(
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id="client-123",
|
||||
client_secret="",
|
||||
)
|
||||
|
||||
assert "String should have at least 1 character" in str(exc_info.value)
|
||||
|
||||
def test_client_secret_max_length(self):
|
||||
"""Test client_secret maximum length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when client_secret exceeds 512 characters
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderCreate(
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id="client-123",
|
||||
client_secret="x" * 513,
|
||||
)
|
||||
|
||||
assert "String should have at most 512 characters" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestIdentityProviderUpdate:
|
||||
"""Test suite for IdentityProviderUpdate schema."""
|
||||
|
||||
def test_valid_identity_provider_update(self):
|
||||
"""Test creating a valid IdentityProviderUpdate instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully
|
||||
- All fields can be updated
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderUpdate(
|
||||
name="Updated Provider",
|
||||
slug="updated-provider",
|
||||
provider_type="oauth2",
|
||||
enabled=True,
|
||||
client_id="updated-client-id",
|
||||
client_secret="updated-secret",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.name == "Updated Provider"
|
||||
assert idp.slug == "updated-provider"
|
||||
assert idp.provider_type == "oauth2"
|
||||
assert idp.enabled is True
|
||||
assert idp.client_id == "updated-client-id"
|
||||
assert idp.client_secret == "updated-secret"
|
||||
|
||||
def test_client_secret_optional(self):
|
||||
"""Test client_secret is optional for IdentityProviderUpdate.
|
||||
|
||||
Asserts:
|
||||
- Instance can be created without client_secret
|
||||
- client_secret defaults to None
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderUpdate(name="Test", slug="test", client_id="client-123")
|
||||
|
||||
# Assert
|
||||
assert idp.client_secret is None
|
||||
|
||||
def test_client_secret_can_be_none(self):
|
||||
"""Test client_secret can explicitly be None.
|
||||
|
||||
Asserts:
|
||||
- client_secret can be set to None
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderUpdate(
|
||||
name="Test", slug="test", client_id="client-123", client_secret=None
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.client_secret is None
|
||||
|
||||
def test_client_secret_min_length_when_provided(self):
|
||||
"""Test client_secret minimum length when provided.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when client_secret is empty string
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderUpdate(
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id="client-123",
|
||||
client_secret="",
|
||||
)
|
||||
|
||||
assert "String should have at least 1 character" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestIdentityProvider:
|
||||
"""Test suite for IdentityProvider schema."""
|
||||
|
||||
def test_valid_identity_provider(self):
|
||||
"""Test creating a valid IdentityProvider instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully
|
||||
- All fields including id and timestamps are set
|
||||
"""
|
||||
# Arrange
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Act
|
||||
idp = IdentityProvider(
|
||||
id=1,
|
||||
name="Test Provider",
|
||||
slug="test-provider",
|
||||
provider_type="oidc",
|
||||
enabled=True,
|
||||
client_id="test-client-id",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.id == 1
|
||||
assert idp.name == "Test Provider"
|
||||
assert idp.slug == "test-provider"
|
||||
assert idp.provider_type == "oidc"
|
||||
assert idp.enabled is True
|
||||
assert idp.client_id == "test-client-id"
|
||||
assert idp.created_at == now
|
||||
assert idp.updated_at == now
|
||||
|
||||
def test_identity_provider_requires_id(self):
|
||||
"""Test id field is required for IdentityProvider.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when id is missing
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProvider(
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id="client-123",
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert "id" in str(exc_info.value)
|
||||
assert "Field required" in str(exc_info.value)
|
||||
|
||||
def test_identity_provider_requires_timestamps(self):
|
||||
"""Test created_at and updated_at are required.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when timestamps are missing
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProvider(id=1, name="Test", slug="test", client_id="client-123")
|
||||
|
||||
assert "created_at" in str(exc_info.value) or "updated_at" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_serialize_client_id_decrypts_encrypted_value(self):
|
||||
"""Test serialize_client_id decrypts encrypted client_id.
|
||||
|
||||
Asserts:
|
||||
- Encrypted client_id (starting with 'gAAAAAB') triggers decryption
|
||||
"""
|
||||
# Arrange
|
||||
now = datetime.utcnow()
|
||||
encrypted_id = "gAAAAABtest_encrypted_token"
|
||||
|
||||
# Act
|
||||
idp = IdentityProvider(
|
||||
id=1,
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id=encrypted_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Assert - This will be tested with actual encryption in integration tests
|
||||
assert idp.client_id == encrypted_id
|
||||
|
||||
def test_serialize_client_id_preserves_unencrypted_value(self):
|
||||
"""Test serialize_client_id preserves non-encrypted client_id.
|
||||
|
||||
Asserts:
|
||||
- Non-encrypted client_id is returned as-is
|
||||
"""
|
||||
# Arrange
|
||||
now = datetime.utcnow()
|
||||
plain_id = "plain-client-id"
|
||||
|
||||
# Act
|
||||
idp = IdentityProvider(
|
||||
id=1,
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id=plain_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.client_id == plain_id
|
||||
|
||||
def test_serialize_client_id_handles_none(self):
|
||||
"""Test serialize_client_id handles None client_id.
|
||||
|
||||
Asserts:
|
||||
- None client_id returns None
|
||||
"""
|
||||
# Arrange
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Act
|
||||
idp = IdentityProvider(
|
||||
id=1,
|
||||
name="Test",
|
||||
slug="test",
|
||||
client_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.client_id is None
|
||||
|
||||
|
||||
class TestIdentityProviderPublic:
|
||||
"""Test suite for IdentityProviderPublic schema."""
|
||||
|
||||
def test_valid_identity_provider_public(self):
|
||||
"""Test creating a valid IdentityProviderPublic instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully
|
||||
- Only public fields are present
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderPublic(
|
||||
id=1, name="Test Provider", slug="test-provider", icon="test-icon"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert idp.id == 1
|
||||
assert idp.name == "Test Provider"
|
||||
assert idp.slug == "test-provider"
|
||||
assert idp.icon == "test-icon"
|
||||
|
||||
def test_identity_provider_public_icon_optional(self):
|
||||
"""Test icon field is optional for IdentityProviderPublic.
|
||||
|
||||
Asserts:
|
||||
- Instance can be created without icon
|
||||
- icon defaults to None
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderPublic(id=1, name="Test", slug="test")
|
||||
|
||||
# Assert
|
||||
assert idp.icon is None
|
||||
|
||||
def test_identity_provider_public_no_sensitive_fields(self):
|
||||
"""Test IdentityProviderPublic doesn't have sensitive fields.
|
||||
|
||||
Asserts:
|
||||
- client_id, client_secret are not present
|
||||
"""
|
||||
# Arrange & Act
|
||||
idp = IdentityProviderPublic(id=1, name="Test", slug="test")
|
||||
|
||||
# Assert
|
||||
assert not hasattr(idp, "client_id")
|
||||
assert not hasattr(idp, "client_secret")
|
||||
|
||||
|
||||
class TestIdentityProviderTemplate:
|
||||
"""Test suite for IdentityProviderTemplate schema."""
|
||||
|
||||
def test_valid_identity_provider_template(self):
|
||||
"""Test creating a valid IdentityProviderTemplate instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully
|
||||
- All fields are set correctly
|
||||
"""
|
||||
# Arrange & Act
|
||||
template = IdentityProviderTemplate(
|
||||
template_id="keycloak",
|
||||
name="Keycloak",
|
||||
provider_type="oidc",
|
||||
issuer_url="https://keycloak.example.com/realms/master",
|
||||
scopes="openid profile email",
|
||||
icon="keycloak",
|
||||
user_mapping={"username": ["preferred_username"], "email": ["email"]},
|
||||
description="Keycloak OIDC provider",
|
||||
configuration_notes="Setup instructions here",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert template.template_id == "keycloak"
|
||||
assert template.name == "Keycloak"
|
||||
assert template.provider_type == "oidc"
|
||||
assert template.issuer_url == "https://keycloak.example.com/realms/master"
|
||||
assert template.scopes == "openid profile email"
|
||||
assert template.icon == "keycloak"
|
||||
assert template.user_mapping == {
|
||||
"username": ["preferred_username"],
|
||||
"email": ["email"],
|
||||
}
|
||||
assert template.description == "Keycloak OIDC provider"
|
||||
assert template.configuration_notes == "Setup instructions here"
|
||||
|
||||
def test_template_required_fields(self):
|
||||
"""Test required fields for IdentityProviderTemplate.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when required fields are missing
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IdentityProviderTemplate()
|
||||
|
||||
error_str = str(exc_info.value)
|
||||
assert "template_id" in error_str or "Field required" in error_str
|
||||
|
||||
def test_template_optional_fields(self):
|
||||
"""Test optional fields for IdentityProviderTemplate.
|
||||
|
||||
Asserts:
|
||||
- Template can be created with only required fields
|
||||
- Optional fields default to None
|
||||
"""
|
||||
# Arrange & Act
|
||||
template = IdentityProviderTemplate(
|
||||
template_id="custom",
|
||||
name="Custom Provider",
|
||||
provider_type="oidc",
|
||||
scopes="openid",
|
||||
description="Custom provider",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert template.issuer_url is None
|
||||
assert template.icon is None
|
||||
assert template.user_mapping is None
|
||||
assert template.configuration_notes is None
|
||||
|
||||
|
||||
class TestTokenExchangeRequest:
|
||||
"""Test suite for TokenExchangeRequest schema."""
|
||||
|
||||
def test_valid_token_exchange_request(self):
|
||||
"""Test creating a valid TokenExchangeRequest instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully with valid code_verifier
|
||||
"""
|
||||
# Arrange
|
||||
# Valid base64url string (43 characters minimum)
|
||||
code_verifier = "a" * 43 + "B1-_" * 10
|
||||
|
||||
# Act
|
||||
request = TokenExchangeRequest(code_verifier=code_verifier)
|
||||
|
||||
# Assert
|
||||
assert request.code_verifier == code_verifier
|
||||
|
||||
def test_code_verifier_min_length(self):
|
||||
"""Test code_verifier minimum length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when code_verifier is less than 43 chars
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TokenExchangeRequest(code_verifier="a" * 42)
|
||||
|
||||
assert "String should have at least 43 characters" in str(exc_info.value)
|
||||
|
||||
def test_code_verifier_max_length(self):
|
||||
"""Test code_verifier maximum length validation.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when code_verifier exceeds 128 chars
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TokenExchangeRequest(code_verifier="a" * 129)
|
||||
|
||||
assert "String should have at most 128 characters" in str(exc_info.value)
|
||||
|
||||
def test_code_verifier_format_validation_valid(self):
|
||||
"""Test code_verifier format validation accepts valid base64url.
|
||||
|
||||
Asserts:
|
||||
- Valid base64url characters (A-Z, a-z, 0-9, -, _) are accepted
|
||||
"""
|
||||
# Arrange
|
||||
valid_verifier = (
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
)
|
||||
|
||||
# Act
|
||||
request = TokenExchangeRequest(code_verifier=valid_verifier)
|
||||
|
||||
# Assert
|
||||
assert request.code_verifier == valid_verifier
|
||||
|
||||
def test_code_verifier_format_validation_invalid(self):
|
||||
"""Test code_verifier format validation rejects invalid characters.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised for non-base64url characters
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TokenExchangeRequest(
|
||||
code_verifier="a" * 43 + "!"
|
||||
) # ! is not valid base64url
|
||||
|
||||
assert "code_verifier must be valid base64url" in str(exc_info.value)
|
||||
|
||||
def test_code_verifier_format_validation_special_chars(self):
|
||||
"""Test code_verifier format validation rejects special characters.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised for special characters like +, =
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TokenExchangeRequest(code_verifier="a" * 43 + "+")
|
||||
|
||||
assert "code_verifier must be valid base64url" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestTokenExchangeResponse:
|
||||
"""Test suite for TokenExchangeResponse schema."""
|
||||
|
||||
def test_valid_token_exchange_response(self):
|
||||
"""Test creating a valid TokenExchangeResponse instance.
|
||||
|
||||
Asserts:
|
||||
- Instance is created successfully with all required fields
|
||||
- Default values are applied
|
||||
"""
|
||||
# Arrange & Act
|
||||
response = TokenExchangeResponse(
|
||||
session_id="session-123",
|
||||
access_token="access-token-xyz",
|
||||
refresh_token="refresh-token-abc",
|
||||
csrf_token="csrf-token-def",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.session_id == "session-123"
|
||||
assert response.access_token == "access-token-xyz"
|
||||
assert response.refresh_token == "refresh-token-abc"
|
||||
assert response.csrf_token == "csrf-token-def"
|
||||
assert response.expires_in == 900 # Default value
|
||||
assert response.token_type == "Bearer" # Default value
|
||||
|
||||
def test_token_exchange_response_required_fields(self):
|
||||
"""Test required fields for TokenExchangeResponse.
|
||||
|
||||
Asserts:
|
||||
- ValidationError is raised when required fields are missing
|
||||
"""
|
||||
# Arrange & Act & Assert
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
TokenExchangeResponse()
|
||||
|
||||
error_str = str(exc_info.value)
|
||||
assert "session_id" in error_str or "access_token" in error_str
|
||||
|
||||
def test_token_exchange_response_custom_expires_in(self):
|
||||
"""Test custom expires_in value for TokenExchangeResponse.
|
||||
|
||||
Asserts:
|
||||
- Custom expires_in value can be set
|
||||
"""
|
||||
# Arrange & Act
|
||||
response = TokenExchangeResponse(
|
||||
session_id="session-123",
|
||||
access_token="access-token-xyz",
|
||||
refresh_token="refresh-token-abc",
|
||||
csrf_token="csrf-token-def",
|
||||
expires_in=1800,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.expires_in == 1800
|
||||
|
||||
def test_token_exchange_response_custom_token_type(self):
|
||||
"""Test custom token_type value for TokenExchangeResponse.
|
||||
|
||||
Asserts:
|
||||
- Custom token_type value can be set
|
||||
"""
|
||||
# Arrange & Act
|
||||
response = TokenExchangeResponse(
|
||||
session_id="session-123",
|
||||
access_token="access-token-xyz",
|
||||
refresh_token="refresh-token-abc",
|
||||
csrf_token="csrf-token-def",
|
||||
token_type="Custom",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.token_type == "Custom"
|
||||
493
backend/tests/auth/identity_providers/test_utils.py
Normal file
493
backend/tests/auth/identity_providers/test_utils.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""Tests for identity_providers.utils module."""
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from auth.identity_providers.utils import (
|
||||
validate_pkce_challenge,
|
||||
validate_pkce_verifier,
|
||||
_secure_compare,
|
||||
get_idp_template,
|
||||
get_idp_templates,
|
||||
)
|
||||
from auth.identity_providers.schema import IdentityProviderTemplate
|
||||
|
||||
|
||||
class TestValidatePkceChallenge:
|
||||
"""Test suite for validate_pkce_challenge function."""
|
||||
|
||||
def test_validate_pkce_challenge_success(self):
|
||||
"""Test validating a valid PKCE challenge.
|
||||
|
||||
Asserts:
|
||||
- Valid S256 challenge passes validation
|
||||
"""
|
||||
# Arrange
|
||||
# Valid base64url-encoded SHA256 hash (43 chars minimum)
|
||||
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert (no exception should be raised)
|
||||
validate_pkce_challenge(code_challenge, code_challenge_method)
|
||||
|
||||
def test_validate_pkce_challenge_plain_method_rejected(self):
|
||||
"""Test PKCE plain method is rejected.
|
||||
|
||||
Asserts:
|
||||
- HTTPException with 400 status is raised for 'plain' method
|
||||
"""
|
||||
# Arrange
|
||||
code_challenge = "test_challenge"
|
||||
code_challenge_method = "plain"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_challenge(code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Only S256 PKCE method is supported" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_challenge_min_length_violation(self):
|
||||
"""Test PKCE challenge minimum length validation.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for challenge shorter than 43 chars
|
||||
"""
|
||||
# Arrange
|
||||
code_challenge = "a" * 42 # Too short
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_challenge(code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "43-128 characters" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_challenge_max_length_violation(self):
|
||||
"""Test PKCE challenge maximum length validation.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for challenge longer than 128 chars
|
||||
"""
|
||||
# Arrange
|
||||
code_challenge = "a" * 129 # Too long
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_challenge(code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "43-128 characters" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_challenge_invalid_characters(self):
|
||||
"""Test PKCE challenge with invalid characters.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for non-base64url characters
|
||||
"""
|
||||
# Arrange
|
||||
code_challenge = "a" * 43 + "!@#$" # Invalid characters
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_challenge(code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "valid base64url" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_challenge_valid_base64url_characters(self):
|
||||
"""Test PKCE challenge with all valid base64url characters.
|
||||
|
||||
Asserts:
|
||||
- All base64url characters (A-Z, a-z, 0-9, -, _) are accepted
|
||||
"""
|
||||
# Arrange
|
||||
code_challenge = (
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr_-" # 45 chars, valid
|
||||
)
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert (no exception should be raised)
|
||||
validate_pkce_challenge(code_challenge, code_challenge_method)
|
||||
|
||||
|
||||
class TestValidatePkceVerifier:
|
||||
"""Test suite for validate_pkce_verifier function."""
|
||||
|
||||
def test_validate_pkce_verifier_success(self):
|
||||
"""Test validating a valid PKCE verifier.
|
||||
|
||||
Asserts:
|
||||
- Valid verifier that matches challenge passes validation
|
||||
"""
|
||||
# Arrange
|
||||
# Valid verifier (RFC 7636 compliant)
|
||||
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
# Corresponding S256 challenge
|
||||
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert (no exception should be raised)
|
||||
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
|
||||
|
||||
def test_validate_pkce_verifier_min_length_violation(self):
|
||||
"""Test PKCE verifier minimum length validation.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for verifier shorter than 43 chars
|
||||
"""
|
||||
# Arrange
|
||||
code_verifier = "a" * 42 # Too short
|
||||
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "43-128 characters" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_verifier_max_length_violation(self):
|
||||
"""Test PKCE verifier maximum length validation.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for verifier longer than 128 chars
|
||||
"""
|
||||
# Arrange
|
||||
code_verifier = "a" * 129 # Too long
|
||||
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "43-128 characters" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_verifier_invalid_characters(self):
|
||||
"""Test PKCE verifier with invalid characters.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for non-base64url characters
|
||||
"""
|
||||
# Arrange
|
||||
code_verifier = "a" * 43 + "!@#" # Invalid characters
|
||||
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "valid base64url" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_verifier_wrong_method(self):
|
||||
"""Test PKCE verifier with wrong challenge method.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for non-S256 method
|
||||
"""
|
||||
# Arrange
|
||||
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
code_challenge_method = "plain"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Only S256 PKCE method is supported" in str(exc_info.value.detail)
|
||||
|
||||
def test_validate_pkce_verifier_mismatch(self):
|
||||
"""Test PKCE verifier that doesn't match challenge.
|
||||
|
||||
Asserts:
|
||||
- HTTPException is raised for mismatched verifier
|
||||
"""
|
||||
# Arrange
|
||||
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
code_challenge = "wrong_challenge_value_that_does_not_match_verifier"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Invalid code_verifier" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
class TestSecureCompare:
|
||||
"""Test suite for _secure_compare function."""
|
||||
|
||||
def test_secure_compare_equal_strings(self):
|
||||
"""Test secure comparison of equal strings.
|
||||
|
||||
Asserts:
|
||||
- Returns True for identical strings
|
||||
"""
|
||||
# Arrange
|
||||
str1 = "test_string_123"
|
||||
str2 = "test_string_123"
|
||||
|
||||
# Act
|
||||
result = _secure_compare(str1, str2)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_secure_compare_different_strings(self):
|
||||
"""Test secure comparison of different strings.
|
||||
|
||||
Asserts:
|
||||
- Returns False for different strings
|
||||
"""
|
||||
# Arrange
|
||||
str1 = "test_string_123"
|
||||
str2 = "test_string_456"
|
||||
|
||||
# Act
|
||||
result = _secure_compare(str1, str2)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_secure_compare_different_lengths(self):
|
||||
"""Test secure comparison of strings with different lengths.
|
||||
|
||||
Asserts:
|
||||
- Returns False for strings of different lengths
|
||||
"""
|
||||
# Arrange
|
||||
str1 = "short"
|
||||
str2 = "much_longer_string"
|
||||
|
||||
# Act
|
||||
result = _secure_compare(str1, str2)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_secure_compare_empty_strings(self):
|
||||
"""Test secure comparison of empty strings.
|
||||
|
||||
Asserts:
|
||||
- Returns True for both empty strings
|
||||
"""
|
||||
# Arrange
|
||||
str1 = ""
|
||||
str2 = ""
|
||||
|
||||
# Act
|
||||
result = _secure_compare(str1, str2)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_secure_compare_one_empty_string(self):
|
||||
"""Test secure comparison with one empty string.
|
||||
|
||||
Asserts:
|
||||
- Returns False when one string is empty
|
||||
"""
|
||||
# Arrange
|
||||
str1 = "non_empty"
|
||||
str2 = ""
|
||||
|
||||
# Act
|
||||
result = _secure_compare(str1, str2)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_secure_compare_case_sensitive(self):
|
||||
"""Test secure comparison is case-sensitive.
|
||||
|
||||
Asserts:
|
||||
- Returns False for strings differing only in case
|
||||
"""
|
||||
# Arrange
|
||||
str1 = "TestString"
|
||||
str2 = "teststring"
|
||||
|
||||
# Act
|
||||
result = _secure_compare(str1, str2)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetIdpTemplate:
|
||||
"""Test suite for get_idp_template function."""
|
||||
|
||||
def test_get_idp_template_keycloak(self):
|
||||
"""Test retrieving Keycloak IdP template.
|
||||
|
||||
Asserts:
|
||||
- Keycloak template is returned with correct structure
|
||||
"""
|
||||
# Act
|
||||
template = get_idp_template("keycloak")
|
||||
|
||||
# Assert
|
||||
assert template is not None
|
||||
assert template["name"] == "Keycloak"
|
||||
assert template["provider_type"] == "oidc"
|
||||
assert "issuer_url" in template
|
||||
assert "scopes" in template
|
||||
assert "user_mapping" in template
|
||||
|
||||
def test_get_idp_template_authentik(self):
|
||||
"""Test retrieving Authentik IdP template.
|
||||
|
||||
Asserts:
|
||||
- Authentik template is returned with correct structure
|
||||
"""
|
||||
# Act
|
||||
template = get_idp_template("authentik")
|
||||
|
||||
# Assert
|
||||
assert template is not None
|
||||
assert template["name"] == "Authentik"
|
||||
assert template["provider_type"] == "oidc"
|
||||
|
||||
def test_get_idp_template_authelia(self):
|
||||
"""Test retrieving Authelia IdP template.
|
||||
|
||||
Asserts:
|
||||
- Authelia template is returned with correct structure
|
||||
"""
|
||||
# Act
|
||||
template = get_idp_template("authelia")
|
||||
|
||||
# Assert
|
||||
assert template is not None
|
||||
assert template["name"] == "Authelia"
|
||||
assert template["provider_type"] == "oidc"
|
||||
|
||||
def test_get_idp_template_casdoor(self):
|
||||
"""Test retrieving Casdoor IdP template.
|
||||
|
||||
Asserts:
|
||||
- Casdoor template is returned with correct structure
|
||||
"""
|
||||
# Act
|
||||
template = get_idp_template("casdoor")
|
||||
|
||||
# Assert
|
||||
assert template is not None
|
||||
assert template["name"] == "Casdoor"
|
||||
assert template["provider_type"] == "oidc"
|
||||
|
||||
def test_get_idp_template_pocketid(self):
|
||||
"""Test retrieving Pocket ID template.
|
||||
|
||||
Asserts:
|
||||
- Pocket ID template is returned with correct structure
|
||||
"""
|
||||
# Act
|
||||
template = get_idp_template("pocketid")
|
||||
|
||||
# Assert
|
||||
assert template is not None
|
||||
assert template["name"] == "Pocket ID"
|
||||
assert template["provider_type"] == "oidc"
|
||||
|
||||
def test_get_idp_template_nonexistent(self):
|
||||
"""Test retrieving non-existent IdP template.
|
||||
|
||||
Asserts:
|
||||
- None is returned for non-existent template
|
||||
"""
|
||||
# Act
|
||||
template = get_idp_template("nonexistent_provider")
|
||||
|
||||
# Assert
|
||||
assert template is None
|
||||
|
||||
|
||||
class TestGetIdpTemplates:
|
||||
"""Test suite for get_idp_templates function."""
|
||||
|
||||
def test_get_idp_templates_returns_list(self):
|
||||
"""Test get_idp_templates returns a list.
|
||||
|
||||
Asserts:
|
||||
- Returns a list of IdentityProviderTemplate objects
|
||||
"""
|
||||
# Act
|
||||
templates = get_idp_templates()
|
||||
|
||||
# Assert
|
||||
assert isinstance(templates, list)
|
||||
assert len(templates) > 0
|
||||
assert all(isinstance(t, IdentityProviderTemplate) for t in templates)
|
||||
|
||||
def test_get_idp_templates_contains_expected_providers(self):
|
||||
"""Test get_idp_templates contains expected providers.
|
||||
|
||||
Asserts:
|
||||
- List contains Keycloak, Authentik, Authelia, Casdoor, Pocket ID
|
||||
"""
|
||||
# Act
|
||||
templates = get_idp_templates()
|
||||
|
||||
# Assert
|
||||
template_ids = [t.template_id for t in templates]
|
||||
assert "keycloak" in template_ids
|
||||
assert "authentik" in template_ids
|
||||
assert "authelia" in template_ids
|
||||
assert "casdoor" in template_ids
|
||||
assert "pocketid" in template_ids
|
||||
|
||||
def test_get_idp_templates_structure(self):
|
||||
"""Test each template has required structure.
|
||||
|
||||
Asserts:
|
||||
- Each template has required fields
|
||||
"""
|
||||
# Act
|
||||
templates = get_idp_templates()
|
||||
|
||||
# Assert
|
||||
for template in templates:
|
||||
assert hasattr(template, "template_id")
|
||||
assert hasattr(template, "name")
|
||||
assert hasattr(template, "provider_type")
|
||||
assert hasattr(template, "scopes")
|
||||
assert hasattr(template, "description")
|
||||
|
||||
def test_get_idp_templates_all_oidc(self):
|
||||
"""Test all templates are OIDC providers.
|
||||
|
||||
Asserts:
|
||||
- All templates have provider_type 'oidc'
|
||||
"""
|
||||
# Act
|
||||
templates = get_idp_templates()
|
||||
|
||||
# Assert
|
||||
for template in templates:
|
||||
assert template.provider_type == "oidc"
|
||||
|
||||
def test_get_idp_templates_has_user_mapping(self):
|
||||
"""Test templates include user mapping configuration.
|
||||
|
||||
Asserts:
|
||||
- Each template has user_mapping with username and email
|
||||
"""
|
||||
# Act
|
||||
templates = get_idp_templates()
|
||||
|
||||
# Assert
|
||||
for template in templates:
|
||||
assert template.user_mapping is not None
|
||||
assert "username" in template.user_mapping
|
||||
assert "email" in template.user_mapping
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
181
backend/tests/auth/test_constants.py
Normal file
181
backend/tests/auth/test_constants.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Tests for auth.constants module."""
|
||||
|
||||
import pytest
|
||||
|
||||
import auth.constants as auth_constants
|
||||
|
||||
|
||||
class TestJWTConstants:
|
||||
"""Test JWT configuration constants."""
|
||||
|
||||
def test_jwt_algorithm_is_set(self):
|
||||
"""Test that JWT algorithm is defined and is a valid algorithm."""
|
||||
assert auth_constants.JWT_ALGORITHM is not None
|
||||
assert isinstance(auth_constants.JWT_ALGORITHM, str)
|
||||
assert auth_constants.JWT_ALGORITHM in ["HS256", "HS384", "HS512"]
|
||||
|
||||
def test_access_token_expiry_is_positive(self):
|
||||
"""Test that access token expiry is a positive integer."""
|
||||
assert auth_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES > 0
|
||||
assert isinstance(auth_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES, int)
|
||||
|
||||
def test_refresh_token_expiry_is_positive(self):
|
||||
"""Test that refresh token expiry is a positive integer."""
|
||||
assert auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS > 0
|
||||
assert isinstance(auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS, int)
|
||||
|
||||
def test_secret_key_is_set(self):
|
||||
"""Test that SECRET_KEY is configured."""
|
||||
assert auth_constants.JWT_SECRET_KEY is not None
|
||||
assert isinstance(auth_constants.JWT_SECRET_KEY, str)
|
||||
assert len(auth_constants.JWT_SECRET_KEY) >= 32
|
||||
|
||||
def test_session_timeout_constants(self):
|
||||
"""Test that session timeout constants are valid."""
|
||||
assert isinstance(auth_constants.SESSION_IDLE_TIMEOUT_ENABLED, bool)
|
||||
assert auth_constants.SESSION_IDLE_TIMEOUT_HOURS >= 0
|
||||
assert isinstance(auth_constants.SESSION_IDLE_TIMEOUT_HOURS, int)
|
||||
assert auth_constants.SESSION_ABSOLUTE_TIMEOUT_HOURS > 0
|
||||
assert isinstance(auth_constants.SESSION_ABSOLUTE_TIMEOUT_HOURS, int)
|
||||
|
||||
|
||||
class TestScopeConstants:
|
||||
"""Test scope configuration constants."""
|
||||
|
||||
def test_users_regular_scope_defined(self):
|
||||
"""Test that users regular scope is properly defined."""
|
||||
assert auth_constants.USERS_REGULAR_SCOPE is not None
|
||||
assert isinstance(auth_constants.USERS_REGULAR_SCOPE, tuple)
|
||||
assert "profile" in auth_constants.USERS_REGULAR_SCOPE
|
||||
assert "users:read" in auth_constants.USERS_REGULAR_SCOPE
|
||||
|
||||
def test_users_admin_scope_defined(self):
|
||||
"""Test that users admin scope is properly defined."""
|
||||
assert auth_constants.USERS_ADMIN_SCOPE is not None
|
||||
assert isinstance(auth_constants.USERS_ADMIN_SCOPE, tuple)
|
||||
assert "users:write" in auth_constants.USERS_ADMIN_SCOPE
|
||||
assert "sessions:read" in auth_constants.USERS_ADMIN_SCOPE
|
||||
assert "sessions:write" in auth_constants.USERS_ADMIN_SCOPE
|
||||
|
||||
def test_gears_scope_defined(self):
|
||||
"""Test that gears scope is properly defined."""
|
||||
assert auth_constants.GEARS_SCOPE is not None
|
||||
assert isinstance(auth_constants.GEARS_SCOPE, tuple)
|
||||
assert "gears:read" in auth_constants.GEARS_SCOPE
|
||||
assert "gears:write" in auth_constants.GEARS_SCOPE
|
||||
|
||||
def test_activities_scope_defined(self):
|
||||
"""Test that activities scope is properly defined."""
|
||||
assert auth_constants.ACTIVITIES_SCOPE is not None
|
||||
assert isinstance(auth_constants.ACTIVITIES_SCOPE, tuple)
|
||||
assert "activities:read" in auth_constants.ACTIVITIES_SCOPE
|
||||
assert "activities:write" in auth_constants.ACTIVITIES_SCOPE
|
||||
|
||||
def test_health_scope_defined(self):
|
||||
"""Test that health scope is properly defined."""
|
||||
assert auth_constants.HEALTH_SCOPE is not None
|
||||
assert isinstance(auth_constants.HEALTH_SCOPE, tuple)
|
||||
assert "health:read" in auth_constants.HEALTH_SCOPE
|
||||
assert "health:write" in auth_constants.HEALTH_SCOPE
|
||||
assert "health_targets:read" in auth_constants.HEALTH_SCOPE
|
||||
assert "health_targets:write" in auth_constants.HEALTH_SCOPE
|
||||
|
||||
def test_identity_providers_scope_defined(self):
|
||||
"""Test that identity providers scopes are properly defined."""
|
||||
assert auth_constants.IDENTITY_PROVIDERS_REGULAR_SCOPE is not None
|
||||
assert isinstance(auth_constants.IDENTITY_PROVIDERS_REGULAR_SCOPE, tuple)
|
||||
assert (
|
||||
"identity_providers:read" in auth_constants.IDENTITY_PROVIDERS_REGULAR_SCOPE
|
||||
)
|
||||
|
||||
assert auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE is not None
|
||||
assert isinstance(auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE, tuple)
|
||||
assert (
|
||||
"identity_providers:write" in auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE
|
||||
)
|
||||
|
||||
def test_server_settings_scope_defined(self):
|
||||
"""Test that server settings scopes are properly defined."""
|
||||
assert auth_constants.SERVER_SETTINGS_REGULAR_SCOPE is not None
|
||||
assert isinstance(auth_constants.SERVER_SETTINGS_REGULAR_SCOPE, tuple)
|
||||
|
||||
assert auth_constants.SERVER_SETTINGS_ADMIN_SCOPE is not None
|
||||
assert isinstance(auth_constants.SERVER_SETTINGS_ADMIN_SCOPE, tuple)
|
||||
assert "server_settings:read" in auth_constants.SERVER_SETTINGS_ADMIN_SCOPE
|
||||
assert "server_settings:write" in auth_constants.SERVER_SETTINGS_ADMIN_SCOPE
|
||||
|
||||
def test_scope_dict_contains_all_scopes(self):
|
||||
"""Test that SCOPE_DICT contains descriptions for all defined scopes."""
|
||||
assert auth_constants.SCOPE_DICT is not None
|
||||
assert isinstance(auth_constants.SCOPE_DICT, dict)
|
||||
|
||||
# Check key scopes are documented
|
||||
expected_scopes = [
|
||||
"profile",
|
||||
"users:read",
|
||||
"users:write",
|
||||
"gears:read",
|
||||
"gears:write",
|
||||
"activities:read",
|
||||
"activities:write",
|
||||
"health:read",
|
||||
"health:write",
|
||||
"server_settings:read",
|
||||
"server_settings:write",
|
||||
]
|
||||
|
||||
for scope in expected_scopes:
|
||||
assert scope in auth_constants.SCOPE_DICT
|
||||
assert isinstance(auth_constants.SCOPE_DICT[scope], str)
|
||||
assert len(auth_constants.SCOPE_DICT[scope]) > 0
|
||||
|
||||
def test_regular_access_scope_composition(self):
|
||||
"""Test that REGULAR_ACCESS_SCOPE includes all expected regular scopes."""
|
||||
assert auth_constants.REGULAR_ACCESS_SCOPE is not None
|
||||
assert isinstance(auth_constants.REGULAR_ACCESS_SCOPE, tuple)
|
||||
|
||||
# Should include users regular scope
|
||||
for scope in auth_constants.USERS_REGULAR_SCOPE:
|
||||
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
|
||||
|
||||
# Should include gears scope
|
||||
for scope in auth_constants.GEARS_SCOPE:
|
||||
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
|
||||
|
||||
# Should include activities scope
|
||||
for scope in auth_constants.ACTIVITIES_SCOPE:
|
||||
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
|
||||
|
||||
# Should include health scope
|
||||
for scope in auth_constants.HEALTH_SCOPE:
|
||||
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
|
||||
|
||||
def test_admin_access_scope_composition(self):
|
||||
"""Test that ADMIN_ACCESS_SCOPE includes all regular and admin scopes."""
|
||||
assert auth_constants.ADMIN_ACCESS_SCOPE is not None
|
||||
assert isinstance(auth_constants.ADMIN_ACCESS_SCOPE, tuple)
|
||||
|
||||
# Should include all regular scopes
|
||||
for scope in auth_constants.REGULAR_ACCESS_SCOPE:
|
||||
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
|
||||
|
||||
# Should include users admin scope
|
||||
for scope in auth_constants.USERS_ADMIN_SCOPE:
|
||||
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
|
||||
|
||||
# Should include identity providers admin scope
|
||||
for scope in auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE:
|
||||
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
|
||||
|
||||
# Should include server settings admin scope
|
||||
for scope in auth_constants.SERVER_SETTINGS_ADMIN_SCOPE:
|
||||
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
|
||||
|
||||
def test_admin_scope_is_superset_of_regular(self):
|
||||
"""Test that admin scope contains all permissions from regular scope."""
|
||||
regular_set = set(auth_constants.REGULAR_ACCESS_SCOPE)
|
||||
admin_set = set(auth_constants.ADMIN_ACCESS_SCOPE)
|
||||
|
||||
assert regular_set.issubset(
|
||||
admin_set
|
||||
), "Admin scope should contain all regular scope permissions"
|
||||
311
backend/tests/auth/test_schema.py
Normal file
311
backend/tests/auth/test_schema.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Tests for auth.schema module.
|
||||
|
||||
This module tests Pydantic schemas and dependency classes for authentication,
|
||||
including login requests, MFA management, and failed attempt tracking.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pydantic import ValidationError
|
||||
|
||||
import auth.schema as auth_schema
|
||||
|
||||
|
||||
class TestLoginRequest:
|
||||
"""Tests for LoginRequest Pydantic model."""
|
||||
|
||||
def test_login_request_valid(self):
|
||||
"""Test valid login request."""
|
||||
request = auth_schema.LoginRequest(username="testuser", password="Password1!")
|
||||
assert request.username == "testuser"
|
||||
assert request.password == "Password1!"
|
||||
|
||||
def test_login_request_username_too_short(self):
|
||||
"""Test login request with empty username."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
auth_schema.LoginRequest(username="", password="Password1!")
|
||||
assert "username" in str(exc_info.value)
|
||||
|
||||
def test_login_request_username_too_long(self):
|
||||
"""Test login request with username exceeding max length."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
auth_schema.LoginRequest(username="a" * 251, password="Password1!")
|
||||
assert "username" in str(exc_info.value)
|
||||
|
||||
def test_login_request_password_too_short(self):
|
||||
"""Test login request with password less than 8 characters."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
auth_schema.LoginRequest(username="testuser", password="Pass1!")
|
||||
assert "password" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestMFALoginRequest:
|
||||
"""Tests for MFALoginRequest Pydantic model."""
|
||||
|
||||
def test_mfa_login_request_valid(self):
|
||||
"""Test valid MFA login request with 6-digit code."""
|
||||
request = auth_schema.MFALoginRequest(username="testuser", mfa_code="123456")
|
||||
assert request.username == "testuser"
|
||||
assert request.mfa_code == "123456"
|
||||
|
||||
def test_mfa_login_request_invalid_code_format_letters(self):
|
||||
"""Test MFA login request with non-numeric code."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
auth_schema.MFALoginRequest(username="testuser", mfa_code="12345a")
|
||||
assert "mfa_code" in str(exc_info.value)
|
||||
|
||||
def test_mfa_login_request_invalid_code_too_short(self):
|
||||
"""Test MFA login request with code less than 6 digits."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
auth_schema.MFALoginRequest(username="testuser", mfa_code="12345")
|
||||
assert "mfa_code" in str(exc_info.value)
|
||||
|
||||
def test_mfa_login_request_invalid_code_too_long(self):
|
||||
"""Test MFA login request with code more than 6 digits."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
auth_schema.MFALoginRequest(username="testuser", mfa_code="1234567")
|
||||
assert "mfa_code" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestMFARequiredResponse:
|
||||
"""Tests for MFARequiredResponse Pydantic model."""
|
||||
|
||||
def test_mfa_required_response_defaults(self):
|
||||
"""Test MFA required response with default values."""
|
||||
response = auth_schema.MFARequiredResponse(username="testuser")
|
||||
assert response.mfa_required is True
|
||||
assert response.username == "testuser"
|
||||
assert response.message == "MFA verification required"
|
||||
|
||||
def test_mfa_required_response_custom_message(self):
|
||||
"""Test MFA required response with custom message."""
|
||||
response = auth_schema.MFARequiredResponse(
|
||||
username="testuser", message="Custom MFA message"
|
||||
)
|
||||
assert response.mfa_required is True
|
||||
assert response.message == "Custom MFA message"
|
||||
|
||||
def test_mfa_required_response_explicit_false(self):
|
||||
"""Test MFA required response with explicit False."""
|
||||
response = auth_schema.MFARequiredResponse(
|
||||
mfa_required=False, username="testuser"
|
||||
)
|
||||
assert response.mfa_required is False
|
||||
|
||||
|
||||
class TestPendingMFALogin:
|
||||
"""Tests for PendingMFALogin class."""
|
||||
|
||||
def test_add_and_get_pending_login(self):
|
||||
"""Test adding and retrieving pending MFA login."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.add_pending_login("testuser", 123)
|
||||
assert store.get_pending_login("testuser") == 123
|
||||
|
||||
def test_get_pending_login_not_found(self):
|
||||
"""Test getting non-existent pending login returns None."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
assert store.get_pending_login("nonexistent") is None
|
||||
|
||||
def test_has_pending_login(self):
|
||||
"""Test checking if username has pending login."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.add_pending_login("testuser", 123)
|
||||
assert store.has_pending_login("testuser") is True
|
||||
assert store.has_pending_login("nonexistent") is False
|
||||
|
||||
def test_delete_pending_login(self):
|
||||
"""Test deleting pending login."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.add_pending_login("testuser", 123)
|
||||
store.delete_pending_login("testuser")
|
||||
assert store.get_pending_login("testuser") is None
|
||||
|
||||
def test_delete_nonexistent_pending_login(self):
|
||||
"""Test deleting non-existent pending login doesn't raise error."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.delete_pending_login("nonexistent") # Should not raise
|
||||
|
||||
def test_clear_all(self):
|
||||
"""Test clearing all pending logins."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.add_pending_login("user1", 1)
|
||||
store.add_pending_login("user2", 2)
|
||||
store.clear_all()
|
||||
assert store.get_pending_login("user1") is None
|
||||
assert store.get_pending_login("user2") is None
|
||||
|
||||
def test_is_not_locked_out_initially(self):
|
||||
"""Test user is not locked out initially."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
assert store.is_locked_out("testuser") is False
|
||||
|
||||
def test_lockout_after_5_failures(self):
|
||||
"""Test 5-minute lockout after 5 failed attempts."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
for _ in range(5):
|
||||
store.record_failed_attempt("testuser")
|
||||
assert store.is_locked_out("testuser") is True
|
||||
lockout_time = store.get_lockout_time("testuser")
|
||||
assert lockout_time is not None
|
||||
assert lockout_time > datetime.now(timezone.utc)
|
||||
|
||||
def test_lockout_after_10_failures(self):
|
||||
"""Test 30-minute lockout after 10 failed attempts."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
for _ in range(10):
|
||||
store.record_failed_attempt("testuser")
|
||||
assert store.is_locked_out("testuser") is True
|
||||
|
||||
def test_lockout_after_15_failures(self):
|
||||
"""Test 2-hour lockout after 15 failed attempts."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
for _ in range(15):
|
||||
store.record_failed_attempt("testuser")
|
||||
assert store.is_locked_out("testuser") is True
|
||||
|
||||
def test_failed_attempt_count_doesnt_increment_while_locked(self):
|
||||
"""Test failed attempt counter doesn't increment during lockout."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
for _ in range(5):
|
||||
store.record_failed_attempt("testuser")
|
||||
# Try to increment during lockout
|
||||
count_before = store.record_failed_attempt("testuser")
|
||||
count_after = store.record_failed_attempt("testuser")
|
||||
assert count_before == count_after
|
||||
|
||||
def test_reset_failed_attempts(self):
|
||||
"""Test resetting failed attempts on successful verification."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.record_failed_attempt("testuser")
|
||||
store.record_failed_attempt("testuser")
|
||||
store.reset_failed_attempts("testuser")
|
||||
assert store.is_locked_out("testuser") is False
|
||||
assert store.get_lockout_time("testuser") is None
|
||||
|
||||
def test_get_lockout_time_returns_none_when_not_locked(self):
|
||||
"""Test get_lockout_time returns None when user not locked."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
assert store.get_lockout_time("testuser") is None
|
||||
|
||||
def test_clear_all_clears_failed_attempts(self):
|
||||
"""Test clear_all() clears both pending logins and failed attempts."""
|
||||
store = auth_schema.PendingMFALogin()
|
||||
store.add_pending_login("testuser", 123)
|
||||
for _ in range(5):
|
||||
store.record_failed_attempt("testuser")
|
||||
store.clear_all()
|
||||
assert store.is_locked_out("testuser") is False
|
||||
|
||||
|
||||
class TestFailedLoginAttempts:
|
||||
"""Tests for FailedLoginAttempts class."""
|
||||
|
||||
def test_is_not_locked_out_initially(self):
|
||||
"""Test user is not locked out initially."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
assert tracker.is_locked_out("testuser") is False
|
||||
|
||||
def test_lockout_after_5_failures(self):
|
||||
"""Test 5-minute lockout after 5 failed login attempts."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
for _ in range(5):
|
||||
tracker.record_failed_attempt("testuser")
|
||||
assert tracker.is_locked_out("testuser") is True
|
||||
lockout_time = tracker.get_lockout_time("testuser")
|
||||
assert lockout_time is not None
|
||||
assert lockout_time > datetime.now(timezone.utc)
|
||||
|
||||
def test_lockout_after_10_failures(self):
|
||||
"""Test 30-minute lockout after 10 failed login attempts."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
for _ in range(10):
|
||||
tracker.record_failed_attempt("testuser")
|
||||
assert tracker.is_locked_out("testuser") is True
|
||||
|
||||
def test_lockout_after_20_failures(self):
|
||||
"""Test 24-hour lockout after 20 failed login attempts."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
for _ in range(20):
|
||||
tracker.record_failed_attempt("testuser")
|
||||
assert tracker.is_locked_out("testuser") is True
|
||||
|
||||
def test_failed_attempt_count_returns_correctly(self):
|
||||
"""Test record_failed_attempt returns current count."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
count1 = tracker.record_failed_attempt("testuser")
|
||||
count2 = tracker.record_failed_attempt("testuser")
|
||||
count3 = tracker.record_failed_attempt("testuser")
|
||||
assert count1 == 1
|
||||
assert count2 == 2
|
||||
assert count3 == 3
|
||||
|
||||
def test_failed_attempt_count_doesnt_increment_while_locked(self):
|
||||
"""Test failed attempt counter doesn't increment during lockout."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
for _ in range(5):
|
||||
tracker.record_failed_attempt("testuser")
|
||||
# Try to increment during lockout
|
||||
count_before = tracker.record_failed_attempt("testuser")
|
||||
count_after = tracker.record_failed_attempt("testuser")
|
||||
assert count_before == count_after
|
||||
|
||||
def test_reset_attempts(self):
|
||||
"""Test resetting failed attempts on successful login."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
tracker.record_failed_attempt("testuser")
|
||||
tracker.record_failed_attempt("testuser")
|
||||
tracker.reset_attempts("testuser")
|
||||
assert tracker.is_locked_out("testuser") is False
|
||||
assert tracker.get_lockout_time("testuser") is None
|
||||
|
||||
def test_get_lockout_time_returns_none_when_not_locked(self):
|
||||
"""Test get_lockout_time returns None when user not locked."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
assert tracker.get_lockout_time("testuser") is None
|
||||
|
||||
def test_clear_all(self):
|
||||
"""Test clearing all failed attempt records."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
tracker.record_failed_attempt("user1")
|
||||
tracker.record_failed_attempt("user2")
|
||||
tracker.clear_all()
|
||||
assert tracker.is_locked_out("user1") is False
|
||||
assert tracker.is_locked_out("user2") is False
|
||||
|
||||
def test_different_users_tracked_independently(self):
|
||||
"""Test different users have independent failed attempt tracking."""
|
||||
tracker = auth_schema.FailedLoginAttempts()
|
||||
for _ in range(3):
|
||||
tracker.record_failed_attempt("user1")
|
||||
for _ in range(5):
|
||||
tracker.record_failed_attempt("user2")
|
||||
assert tracker.is_locked_out("user1") is False
|
||||
assert tracker.is_locked_out("user2") is True
|
||||
|
||||
|
||||
class TestDependencyFunctions:
|
||||
"""Tests for dependency injection functions."""
|
||||
|
||||
def test_get_pending_mfa_store(self):
|
||||
"""Test get_pending_mfa_store returns PendingMFALogin instance."""
|
||||
store = auth_schema.get_pending_mfa_store()
|
||||
assert isinstance(store, auth_schema.PendingMFALogin)
|
||||
|
||||
def test_get_pending_mfa_store_returns_singleton(self):
|
||||
"""Test get_pending_mfa_store returns same instance."""
|
||||
store1 = auth_schema.get_pending_mfa_store()
|
||||
store2 = auth_schema.get_pending_mfa_store()
|
||||
assert store1 is store2
|
||||
|
||||
def test_get_failed_login_attempts(self):
|
||||
"""Test get_failed_login_attempts returns FailedLoginAttempts instance."""
|
||||
tracker = auth_schema.get_failed_login_attempts()
|
||||
assert isinstance(tracker, auth_schema.FailedLoginAttempts)
|
||||
|
||||
def test_get_failed_login_attempts_returns_singleton(self):
|
||||
"""Test get_failed_login_attempts returns same instance."""
|
||||
tracker1 = auth_schema.get_failed_login_attempts()
|
||||
tracker2 = auth_schema.get_failed_login_attempts()
|
||||
assert tracker1 is tracker2
|
||||
287
backend/tests/auth/test_security.py
Normal file
287
backend/tests/auth/test_security.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Tests for auth.security module."""
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import SecurityScopes
|
||||
|
||||
import auth.security as auth_security
|
||||
import auth.token_manager as auth_token_manager
|
||||
|
||||
|
||||
class TestGetToken:
|
||||
"""Test get_token function for token retrieval logic."""
|
||||
|
||||
def test_get_access_token_from_header(self):
|
||||
"""Test access token retrieval from Authorization header."""
|
||||
result = auth_security.get_token(
|
||||
non_cookie_token="test_token",
|
||||
cookie_token=None,
|
||||
client_type="web",
|
||||
token_type="access",
|
||||
)
|
||||
assert result == "test_token"
|
||||
|
||||
def test_get_access_token_missing_raises_error(self):
|
||||
"""Test that missing access token raises 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.get_token(
|
||||
non_cookie_token=None,
|
||||
cookie_token=None,
|
||||
client_type="web",
|
||||
token_type="access",
|
||||
)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Access token missing" in exc_info.value.detail
|
||||
|
||||
def test_get_refresh_token_from_cookie_for_web(self):
|
||||
"""Test refresh token retrieval from cookie for web client."""
|
||||
result = auth_security.get_token(
|
||||
non_cookie_token=None,
|
||||
cookie_token="refresh_cookie_token",
|
||||
client_type="web",
|
||||
token_type="refresh",
|
||||
)
|
||||
assert result == "refresh_cookie_token"
|
||||
|
||||
def test_get_refresh_token_from_header_for_mobile(self):
|
||||
"""Test refresh token retrieval from header for mobile client."""
|
||||
result = auth_security.get_token(
|
||||
non_cookie_token="refresh_header_token",
|
||||
cookie_token=None,
|
||||
client_type="mobile",
|
||||
token_type="refresh",
|
||||
)
|
||||
assert result == "refresh_header_token"
|
||||
|
||||
def test_get_refresh_token_missing_for_web_raises_error(self):
|
||||
"""Test that missing refresh token from cookie for web raises 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.get_token(
|
||||
non_cookie_token=None,
|
||||
cookie_token=None,
|
||||
client_type="web",
|
||||
token_type="refresh",
|
||||
)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Refresh token missing from cookie" in exc_info.value.detail
|
||||
|
||||
def test_get_refresh_token_missing_for_mobile_raises_error(self):
|
||||
"""Test that missing refresh token from header for mobile raises 401."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.get_token(
|
||||
non_cookie_token=None,
|
||||
cookie_token=None,
|
||||
client_type="mobile",
|
||||
token_type="refresh",
|
||||
)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert (
|
||||
"Refresh token missing from Authorization header" in exc_info.value.detail
|
||||
)
|
||||
|
||||
def test_invalid_token_type_raises_error(self):
|
||||
"""Test that invalid token type raises 403."""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.get_token(
|
||||
non_cookie_token="test_token",
|
||||
cookie_token=None,
|
||||
client_type="web",
|
||||
token_type="invalid_type",
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Invalid client type or token type" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestAccessTokenValidation:
|
||||
"""Test access token validation functions."""
|
||||
|
||||
def test_validate_access_token_success(self, token_manager, sample_user_read):
|
||||
"""Test successful access token validation."""
|
||||
# Create a valid token
|
||||
_, access_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
# Should not raise an exception
|
||||
try:
|
||||
auth_security.validate_access_token(access_token, token_manager)
|
||||
except HTTPException:
|
||||
pytest.fail("Valid token should not raise HTTPException")
|
||||
|
||||
def test_validate_access_token_with_expired_token(self, token_manager):
|
||||
"""Test that expired token raises HTTPException."""
|
||||
expired_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzaWQiOiJzZXNzaW9uLWlkIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwiYXVkIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwic3ViIjoxLCJzY29wZSI6WyJwcm9maWxlIl0sImlhdCI6MTc1OTk1MzE4NSwibmJmIjoxNzU5OTUzMTg1LCJleHAiOjE3NTk5NTQwODUsImp0aSI6Ijc5ZjY0MmVkLTQ3M2QtNDEwZi1hYzI1LTIyNjEwNTlhMzg2MiJ9.VSizGzvIIi_EJYD_YmfZBEBE_9aJbhLW-25cD1kEOeM"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.validate_access_token(expired_token, token_manager)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_validate_access_token_with_invalid_token(self, token_manager):
|
||||
"""Test that invalid token raises HTTPException."""
|
||||
invalid_token = "invalid.token.here"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.validate_access_token(invalid_token, token_manager)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestGetSubFromAccessToken:
|
||||
"""Test extracting user ID from access token."""
|
||||
|
||||
def test_get_sub_from_valid_token(self, token_manager, sample_user_read):
|
||||
"""Test extracting user ID from valid access token."""
|
||||
_, access_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
sub = auth_security.get_sub_from_access_token(access_token, token_manager)
|
||||
assert sub == sample_user_read.id
|
||||
assert isinstance(sub, int)
|
||||
|
||||
def test_get_sub_from_invalid_token_raises_error(self, token_manager):
|
||||
"""Test that invalid token raises HTTPException."""
|
||||
invalid_token = "invalid.token.here"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.get_sub_from_access_token(invalid_token, token_manager)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestGetSidFromAccessToken:
|
||||
"""Test extracting session ID from access token."""
|
||||
|
||||
def test_get_sid_from_valid_token(self, token_manager, sample_user_read):
|
||||
"""Test extracting session ID from valid access token."""
|
||||
session_id = "test-session-123"
|
||||
_, access_token = token_manager.create_token(
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
sid = auth_security.get_sid_from_access_token(access_token, token_manager)
|
||||
assert sid == session_id
|
||||
assert isinstance(sid, str)
|
||||
|
||||
def test_get_sid_from_invalid_token_raises_error(self, token_manager):
|
||||
"""Test that invalid token raises HTTPException."""
|
||||
invalid_token = "invalid.token.here"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.get_sid_from_access_token(invalid_token, token_manager)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestRefreshTokenValidation:
|
||||
"""Test refresh token validation functions."""
|
||||
|
||||
def test_validate_refresh_token_success(self, token_manager, sample_user_read):
|
||||
"""Test successful refresh token validation."""
|
||||
_, refresh_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
# Should not raise an exception
|
||||
try:
|
||||
auth_security.validate_refresh_token(refresh_token, token_manager)
|
||||
except HTTPException:
|
||||
pytest.fail("Valid refresh token should not raise HTTPException")
|
||||
|
||||
def test_validate_refresh_token_with_expired_token(self, token_manager):
|
||||
"""Test that expired refresh token raises HTTPException."""
|
||||
expired_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzaWQiOiJzZXNzaW9uLWlkIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwiYXVkIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwic3ViIjoxLCJzY29wZSI6WyJwcm9maWxlIl0sImlhdCI6MTc1OTk1MzE4NSwibmJmIjoxNzU5OTUzMTg1LCJleHAiOjE3NTk5NTQwODUsImp0aSI6Ijc5ZjY0MmVkLTQ3M2QtNDEwZi1hYzI1LTIyNjEwNTlhMzg2MiJ9.VSizGzvIIi_EJYD_YmfZBEBE_9aJbhLW-25cD1kEOeM"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.validate_refresh_token(expired_token, token_manager)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestGetSubFromRefreshToken:
|
||||
"""Test extracting user ID from refresh token."""
|
||||
|
||||
def test_get_sub_from_valid_refresh_token(self, token_manager, sample_user_read):
|
||||
"""Test extracting user ID from valid refresh token."""
|
||||
_, refresh_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
sub = auth_security.get_sub_from_refresh_token(refresh_token, token_manager)
|
||||
assert sub == sample_user_read.id
|
||||
assert isinstance(sub, int)
|
||||
|
||||
|
||||
class TestGetSidFromRefreshToken:
|
||||
"""Test extracting session ID from refresh token."""
|
||||
|
||||
def test_get_sid_from_valid_refresh_token(self, token_manager, sample_user_read):
|
||||
"""Test extracting session ID from valid refresh token."""
|
||||
session_id = "test-session-456"
|
||||
_, refresh_token = token_manager.create_token(
|
||||
session_id, sample_user_read, auth_token_manager.TokenType.REFRESH
|
||||
)
|
||||
|
||||
sid = auth_security.get_sid_from_refresh_token(refresh_token, token_manager)
|
||||
assert sid == session_id
|
||||
assert isinstance(sid, str)
|
||||
|
||||
|
||||
class TestCheckScopes:
|
||||
"""Test scope validation function."""
|
||||
|
||||
def test_check_scopes_with_valid_scopes(self, token_manager, sample_user_read):
|
||||
"""Test that valid scopes pass validation."""
|
||||
_, access_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
security_scopes = SecurityScopes(scopes=["profile", "users:read"])
|
||||
|
||||
# Should not raise an exception
|
||||
try:
|
||||
auth_security.check_scopes(access_token, token_manager, security_scopes)
|
||||
except HTTPException:
|
||||
pytest.fail("Valid scopes should not raise HTTPException")
|
||||
|
||||
def test_check_scopes_with_missing_scope(self, token_manager, sample_user_read):
|
||||
"""Test that missing required scope raises 403."""
|
||||
_, access_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
# Request a scope that the user doesn't have
|
||||
security_scopes = SecurityScopes(scopes=["admin:write"])
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_security.check_scopes(access_token, token_manager, security_scopes)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Missing permissions" in exc_info.value.detail
|
||||
|
||||
def test_check_scopes_with_no_required_scopes(
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""Test that no required scopes passes validation."""
|
||||
_, access_token = token_manager.create_token(
|
||||
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
|
||||
)
|
||||
|
||||
security_scopes = SecurityScopes(scopes=[])
|
||||
|
||||
# Should not raise an exception
|
||||
try:
|
||||
auth_security.check_scopes(access_token, token_manager, security_scopes)
|
||||
except HTTPException:
|
||||
pytest.fail("Empty required scopes should not raise HTTPException")
|
||||
|
||||
|
||||
class TestGetAndReturnTokens:
|
||||
"""Test simple token return functions."""
|
||||
|
||||
def test_get_and_return_access_token(self):
|
||||
"""Test that access token is returned unchanged."""
|
||||
test_token = "test_access_token"
|
||||
result = auth_security.get_and_return_access_token(test_token)
|
||||
assert result == test_token
|
||||
|
||||
def test_get_and_return_refresh_token(self):
|
||||
"""Test that refresh token is returned unchanged."""
|
||||
test_token = "test_refresh_token"
|
||||
result = auth_security.get_and_return_refresh_token(test_token)
|
||||
assert result == test_token
|
||||
398
backend/tests/auth/test_utils.py
Normal file
398
backend/tests/auth/test_utils.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""Tests for auth.utils module."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from fastapi import HTTPException, Response
|
||||
|
||||
import auth.utils as auth_utils
|
||||
import auth.token_manager as auth_token_manager
|
||||
import users.user.schema as user_schema
|
||||
|
||||
|
||||
class TestAuthenticateUser:
|
||||
"""Test user authentication function."""
|
||||
|
||||
def test_authenticate_user_success(
|
||||
self, password_hasher, mock_db, sample_user_read
|
||||
):
|
||||
"""Test successful user authentication."""
|
||||
# Arrange
|
||||
username = "testuser"
|
||||
password = "TestPassword123!"
|
||||
|
||||
# Create a user with hashed password
|
||||
hashed_password = password_hasher.hash_password(password)
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = sample_user_read.id
|
||||
mock_user.password = hashed_password
|
||||
mock_user.username = username
|
||||
|
||||
# Mock the CRUD function to return our user
|
||||
with patch("auth.utils.users_crud.authenticate_user", return_value=mock_user):
|
||||
# Act
|
||||
result = auth_utils.authenticate_user(
|
||||
username, password, password_hasher, mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
def test_authenticate_user_invalid_username(self, password_hasher, mock_db):
|
||||
"""Test authentication with invalid username raises 401."""
|
||||
with patch("auth.utils.users_crud.authenticate_user", return_value=None):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_utils.authenticate_user(
|
||||
"nonexistent", "password", password_hasher, mock_db
|
||||
)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid username" in exc_info.value.detail
|
||||
|
||||
def test_authenticate_user_invalid_password(self, password_hasher, mock_db):
|
||||
"""Test authentication with invalid password raises 401."""
|
||||
# Arrange
|
||||
username = "testuser"
|
||||
correct_password = "CorrectPassword123!"
|
||||
wrong_password = "WrongPassword123!"
|
||||
|
||||
hashed_password = password_hasher.hash_password(correct_password)
|
||||
mock_user = MagicMock()
|
||||
mock_user.password = hashed_password
|
||||
|
||||
with patch("auth.utils.users_crud.authenticate_user", return_value=mock_user):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_utils.authenticate_user(
|
||||
username, wrong_password, password_hasher, mock_db
|
||||
)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid password" in exc_info.value.detail
|
||||
|
||||
def test_authenticate_user_updates_password_hash_if_needed(
|
||||
self, password_hasher, mock_db, sample_user_read
|
||||
):
|
||||
"""Test that password hash is updated if algorithm changed."""
|
||||
# Arrange
|
||||
username = "testuser"
|
||||
password = "TestPassword123!"
|
||||
|
||||
# Use a different hasher to simulate old hash
|
||||
from pwdlib.hashers.bcrypt import BcryptHasher
|
||||
|
||||
old_hasher_instance = BcryptHasher()
|
||||
old_hash = old_hasher_instance.hash(password)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = sample_user_read.id
|
||||
mock_user.password = old_hash
|
||||
mock_user.username = username
|
||||
|
||||
with patch("auth.utils.users_crud.authenticate_user", return_value=mock_user):
|
||||
with patch("auth.utils.users_crud.edit_user_password") as mock_edit:
|
||||
# Act
|
||||
result = auth_utils.authenticate_user(
|
||||
username, password, password_hasher, mock_db
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
# Password should be updated since we're using different hasher
|
||||
mock_edit.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateTokens:
|
||||
"""Test token creation function."""
|
||||
|
||||
def test_create_tokens_generates_all_tokens(self, token_manager, sample_user_read):
|
||||
"""Test that create_tokens generates all required tokens."""
|
||||
# Act
|
||||
(
|
||||
session_id,
|
||||
access_token_exp,
|
||||
access_token,
|
||||
refresh_token_exp,
|
||||
refresh_token,
|
||||
csrf_token,
|
||||
) = auth_utils.create_tokens(sample_user_read, token_manager)
|
||||
|
||||
# Assert
|
||||
assert session_id is not None
|
||||
assert isinstance(session_id, str)
|
||||
assert len(session_id) > 0
|
||||
|
||||
assert access_token_exp is not None
|
||||
assert isinstance(access_token_exp, datetime)
|
||||
assert access_token_exp > datetime.now(timezone.utc)
|
||||
|
||||
assert access_token is not None
|
||||
assert isinstance(access_token, str)
|
||||
assert len(access_token) > 0
|
||||
|
||||
assert refresh_token_exp is not None
|
||||
assert isinstance(refresh_token_exp, datetime)
|
||||
assert refresh_token_exp > datetime.now(timezone.utc)
|
||||
|
||||
assert refresh_token is not None
|
||||
assert isinstance(refresh_token, str)
|
||||
assert len(refresh_token) > 0
|
||||
|
||||
assert csrf_token is not None
|
||||
assert isinstance(csrf_token, str)
|
||||
assert len(csrf_token) >= 32
|
||||
|
||||
def test_create_tokens_with_custom_session_id(
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""Test that create_tokens uses provided session ID."""
|
||||
# Arrange
|
||||
custom_session_id = "custom-session-123"
|
||||
|
||||
# Act
|
||||
(
|
||||
session_id,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = auth_utils.create_tokens(sample_user_read, token_manager, custom_session_id)
|
||||
|
||||
# Assert
|
||||
assert session_id == custom_session_id
|
||||
|
||||
def test_create_tokens_refresh_expires_after_access(
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""Test that refresh token expires after access token."""
|
||||
# Act
|
||||
(
|
||||
_,
|
||||
access_token_exp,
|
||||
_,
|
||||
refresh_token_exp,
|
||||
_,
|
||||
_,
|
||||
) = auth_utils.create_tokens(sample_user_read, token_manager)
|
||||
|
||||
# Assert
|
||||
assert refresh_token_exp > access_token_exp
|
||||
|
||||
def test_create_tokens_generates_unique_tokens(
|
||||
self, token_manager, sample_user_read
|
||||
):
|
||||
"""Test that multiple calls generate unique tokens."""
|
||||
# Act
|
||||
(_, _, access_token1, _, refresh_token1, csrf_token1) = (
|
||||
auth_utils.create_tokens(sample_user_read, token_manager)
|
||||
)
|
||||
(_, _, access_token2, _, refresh_token2, csrf_token2) = (
|
||||
auth_utils.create_tokens(sample_user_read, token_manager)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert access_token1 != access_token2
|
||||
assert refresh_token1 != refresh_token2
|
||||
assert csrf_token1 != csrf_token2
|
||||
|
||||
|
||||
class TestCompleteLogin:
|
||||
"""Test complete_login function."""
|
||||
|
||||
def test_complete_login_for_web_client(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test complete_login for web client sets cookies and returns tokens."""
|
||||
# Arrange
|
||||
response = Response()
|
||||
client_type = "web"
|
||||
|
||||
with patch("auth.utils.session_utils.create_session"):
|
||||
# Act
|
||||
result = auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert "session_id" in result
|
||||
assert "access_token" in result
|
||||
assert "csrf_token" in result
|
||||
assert "token_type" in result
|
||||
assert "expires_in" in result
|
||||
|
||||
assert result["token_type"] == "bearer"
|
||||
assert isinstance(result["expires_in"], int)
|
||||
|
||||
# Check that refresh token cookie was set
|
||||
assert "endurain_refresh_token" in response.headers.get("set-cookie", "")
|
||||
|
||||
def test_complete_login_for_mobile_client(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test complete_login for mobile client returns tokens."""
|
||||
# Arrange
|
||||
response = Response()
|
||||
client_type = "mobile"
|
||||
|
||||
with patch("auth.utils.session_utils.create_session"):
|
||||
# Act
|
||||
result = auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert "session_id" in result
|
||||
assert "access_token" in result
|
||||
assert "csrf_token" in result
|
||||
assert result["token_type"] == "bearer"
|
||||
|
||||
def test_complete_login_creates_session(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test that complete_login creates a session in the database."""
|
||||
# Arrange
|
||||
response = Response()
|
||||
client_type = "web"
|
||||
|
||||
with patch("auth.utils.session_utils.create_session") as mock_create_session:
|
||||
# Act
|
||||
result = auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_create_session.assert_called_once()
|
||||
call_args = mock_create_session.call_args
|
||||
|
||||
# Verify session_id matches
|
||||
assert call_args[0][0] == result["session_id"]
|
||||
# Verify user was passed
|
||||
assert call_args[0][1] == sample_user_read
|
||||
# Verify request was passed
|
||||
assert call_args[0][2] == mock_request
|
||||
|
||||
def test_complete_login_invalid_client_type_raises_error(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test that invalid client type raises 403."""
|
||||
# Arrange
|
||||
response = Response()
|
||||
invalid_client_type = "invalid"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
invalid_client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Invalid client type" in exc_info.value.detail
|
||||
|
||||
def test_complete_login_sets_secure_cookie_for_https(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test that secure flag is set on cookie when using HTTPS."""
|
||||
# Arrange
|
||||
response = Response()
|
||||
client_type = "web"
|
||||
|
||||
with patch("auth.utils.session_utils.create_session"):
|
||||
with patch.dict("os.environ", {"FRONTEND_PROTOCOL": "https"}):
|
||||
# Act
|
||||
auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
# Assert
|
||||
set_cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "secure" in set_cookie_header.lower()
|
||||
|
||||
def test_complete_login_cookie_attributes(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test that refresh token cookie has correct security attributes."""
|
||||
# Arrange
|
||||
response = Response()
|
||||
client_type = "web"
|
||||
|
||||
with patch("auth.utils.session_utils.create_session"):
|
||||
# Act
|
||||
auth_utils.complete_login(
|
||||
response,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
# Assert
|
||||
set_cookie_header = response.headers.get("set-cookie", "")
|
||||
assert "endurain_refresh_token" in set_cookie_header
|
||||
assert "httponly" in set_cookie_header.lower()
|
||||
assert "samesite=strict" in set_cookie_header.lower()
|
||||
assert "path=/" in set_cookie_header.lower()
|
||||
|
||||
def test_complete_login_returns_different_tokens_on_each_call(
|
||||
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
|
||||
):
|
||||
"""Test that each login generates unique tokens."""
|
||||
# Arrange
|
||||
response1 = Response()
|
||||
response2 = Response()
|
||||
client_type = "web"
|
||||
|
||||
with patch("auth.utils.session_utils.create_session"):
|
||||
# Act
|
||||
result1 = auth_utils.complete_login(
|
||||
response1,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
result2 = auth_utils.complete_login(
|
||||
response2,
|
||||
mock_request,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
password_hasher,
|
||||
token_manager,
|
||||
mock_db,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result1["session_id"] != result2["session_id"]
|
||||
assert result1["access_token"] != result2["access_token"]
|
||||
assert result1["csrf_token"] != result2["csrf_token"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""
|
||||
Tests for Endurain session module backend application.
|
||||
"""
|
||||
@@ -1,994 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class TestLoginEndpointSecurity:
|
||||
"""
|
||||
Test suite for verifying the security and behavior of the login endpoint.
|
||||
|
||||
This class contains tests that cover various scenarios for the login endpoint, including:
|
||||
- Successful login without Multi-Factor Authentication (MFA) for different client types.
|
||||
- Login attempts when MFA is required, ensuring the correct response and MFA flow.
|
||||
- Handling of invalid client types, ensuring forbidden access is enforced.
|
||||
- Login attempts with invalid credentials, ensuring proper error handling.
|
||||
- Login attempts with inactive users, ensuring access is denied as expected.
|
||||
|
||||
Each test uses extensive mocking to simulate authentication, user activity checks, MFA status, and session/token creation, allowing for isolated and reliable testing of the endpoint's logic and security requirements.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_type, expected_status, returns_tokens",
|
||||
[
|
||||
("web", status.HTTP_200_OK, False),
|
||||
("mobile", status.HTTP_200_OK, True),
|
||||
],
|
||||
)
|
||||
def test_login_without_mfa(
|
||||
self,
|
||||
fast_api_app,
|
||||
fast_api_client,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
expected_status,
|
||||
returns_tokens,
|
||||
):
|
||||
"""
|
||||
Test the login endpoint behavior when Multi-Factor Authentication (MFA) is not enabled for the user.
|
||||
|
||||
This test verifies that:
|
||||
- The login process completes successfully without requiring MFA.
|
||||
- The correct response is returned based on whether tokens are expected.
|
||||
- The appropriate authentication, user activity, and MFA checks are patched and simulated.
|
||||
- The fake store is not called during the process.
|
||||
|
||||
Args:
|
||||
fast_api_app: The FastAPI application instance under test.
|
||||
fast_api_client: The test client for making HTTP requests to the FastAPI app.
|
||||
sample_user_read: A sample user object returned by the authentication mock.
|
||||
client_type: The type of client making the request (used in headers and app state).
|
||||
expected_status: The expected HTTP status code of the response.
|
||||
returns_tokens: Boolean indicating if the endpoint should return tokens or just a session ID.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.profile_utils.is_mfa_enabled_for_user"
|
||||
) as mock_mfa, patch(
|
||||
"session.router.auth_utils.complete_login"
|
||||
) as mock_complete:
|
||||
mock_auth.return_value = sample_user_read
|
||||
mock_mfa.return_value = False
|
||||
mock_complete.return_value = (
|
||||
{"session_id": "test-session"}
|
||||
if not returns_tokens
|
||||
else {
|
||||
"access_token": "token",
|
||||
"refresh_token": "refresh",
|
||||
"session_id": "session",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 900,
|
||||
}
|
||||
)
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/login",
|
||||
data={"username": "testuser", "password": "secret"},
|
||||
headers={"X-Client-Type": client_type},
|
||||
)
|
||||
assert resp.status_code == expected_status
|
||||
body = resp.json()
|
||||
if returns_tokens:
|
||||
assert body["access_token"] == "token"
|
||||
assert body["refresh_token"] == "refresh"
|
||||
assert body["session_id"] == "session"
|
||||
assert body["token_type"] == "Bearer"
|
||||
assert isinstance(body["expires_in"], int)
|
||||
else:
|
||||
assert body == {"session_id": "test-session"}
|
||||
assert fast_api_app.state.fake_store.calls == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_type, expected_status",
|
||||
[
|
||||
("web", status.HTTP_202_ACCEPTED),
|
||||
("mobile", status.HTTP_200_OK),
|
||||
],
|
||||
)
|
||||
def test_login_with_mfa_required(
|
||||
self,
|
||||
fast_api_app,
|
||||
fast_api_client,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
expected_status,
|
||||
):
|
||||
"""
|
||||
Test the login endpoint when Multi-Factor Authentication (MFA) is required.
|
||||
|
||||
This test verifies that when a user with MFA enabled attempts to log in,
|
||||
the API responds with the correct status code and indicates that MFA is required.
|
||||
It mocks the authentication, user activity check, and MFA status check to simulate
|
||||
the scenario where MFA is enabled for the user.
|
||||
|
||||
Args:
|
||||
fast_api_app: The FastAPI application instance under test.
|
||||
fast_api_client: The test client for making HTTP requests to the FastAPI app.
|
||||
sample_user_read: A sample user object returned by the authentication mock.
|
||||
client_type: The type of client making the request (used in headers).
|
||||
expected_status: The expected HTTP status code for the response.
|
||||
|
||||
Asserts:
|
||||
- The response status code matches the expected status.
|
||||
- The response JSON contains 'mfa_required' set to True.
|
||||
- The response JSON contains the correct 'username'.
|
||||
- The fake_store in the app state records the correct call.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch("session.router.profile_utils.is_mfa_enabled_for_user") as mock_mfa:
|
||||
mock_auth.return_value = sample_user_read
|
||||
mock_mfa.return_value = True
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/login",
|
||||
data={"username": "testuser", "password": "secret"},
|
||||
headers={"X-Client-Type": client_type},
|
||||
)
|
||||
assert resp.status_code == expected_status
|
||||
body = resp.json()
|
||||
assert body["mfa_required"] is True
|
||||
assert body["username"] == "testuser"
|
||||
assert fast_api_app.state.fake_store.calls == [
|
||||
("testuser", sample_user_read.id)
|
||||
]
|
||||
|
||||
def test_invalid_client_type_forbidden(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test that a login attempt with an invalid client type returns a 403 Forbidden response.
|
||||
|
||||
This test sets the application's client type to "desktop" and mocks the authentication,
|
||||
user activity check, MFA status, token creation, and session creation utilities. It then
|
||||
sends a POST request to the "/auth/login" endpoint with the "X-Client-Type" header set to "desktop".
|
||||
The test asserts that the response status code is 403 Forbidden and the response detail
|
||||
indicates an invalid client type.
|
||||
|
||||
Args:
|
||||
fast_api_app: The FastAPI application instance.
|
||||
fast_api_client: The test client for making HTTP requests.
|
||||
sample_user_read: A sample user object returned by the authentication mock.
|
||||
"""
|
||||
fast_api_app.state._client_type = "desktop"
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.profile_utils.is_mfa_enabled_for_user"
|
||||
) as mock_mfa, patch(
|
||||
"session.router.auth_utils.create_tokens"
|
||||
) as mock_create_tokens, patch(
|
||||
"session.router.session_utils.create_session"
|
||||
) as mock_create_session:
|
||||
mock_auth.return_value = sample_user_read
|
||||
mock_mfa.return_value = False
|
||||
|
||||
mock_create_tokens.return_value = (
|
||||
"sid",
|
||||
object(),
|
||||
"acc",
|
||||
object(),
|
||||
"ref",
|
||||
"csrf",
|
||||
)
|
||||
mock_create_session.return_value = None
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/login",
|
||||
data={"username": "x", "password": "y"},
|
||||
headers={"X-Client-Type": "desktop"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert resp.json()["detail"] == "Invalid client type"
|
||||
|
||||
def test_login_with_invalid_credentials(self, password_hasher, mock_db):
|
||||
"""
|
||||
Test that the login endpoint raises an HTTPException with status code 401
|
||||
when invalid credentials are provided. Mocks the authenticate_user function
|
||||
to simulate authentication failure and verifies that the exception is raised
|
||||
with the correct status code and detail.
|
||||
"""
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth:
|
||||
mock_auth.side_effect = HTTPException(
|
||||
status_code=401, detail="Invalid username"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_auth("invalid", "password", password_hasher, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_login_with_inactive_user(self, sample_inactive_user):
|
||||
"""
|
||||
Test that the login endpoint raises an HTTPException with status code 403
|
||||
when attempting to authenticate an inactive user.
|
||||
|
||||
This test mocks the authentication and user activity check utilities to simulate
|
||||
the scenario where a user is found but is inactive. It asserts that the correct
|
||||
exception is raised with the expected status code.
|
||||
"""
|
||||
with patch("session.router.auth_utils.authenticate_user") as mock_auth:
|
||||
with patch("session.router.users_utils.check_user_is_active") as mock_check:
|
||||
mock_auth.return_value = sample_inactive_user
|
||||
mock_check.side_effect = HTTPException(
|
||||
status_code=403, detail="User is inactive"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_check(sample_inactive_user)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
class TestMFAVerifyEndpoint:
|
||||
"""
|
||||
Test suite for the MFA verification endpoint (/auth/mfa/verify).
|
||||
|
||||
This class contains tests that cover various scenarios for the MFA verification endpoint, including:
|
||||
- Successful MFA verification and login for different client types (web and mobile).
|
||||
- Handling of cases where no pending MFA login is found.
|
||||
- Handling of invalid MFA codes.
|
||||
- Handling of cases where the user is not found after MFA verification.
|
||||
- Handling of inactive users after MFA verification.
|
||||
- Handling of invalid client types during MFA verification.
|
||||
|
||||
Each test uses mocking to simulate the MFA verification flow, user lookup, and session creation.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_type, expected_status, returns_tokens",
|
||||
[
|
||||
("web", status.HTTP_200_OK, False),
|
||||
("mobile", status.HTTP_200_OK, True),
|
||||
],
|
||||
)
|
||||
def test_mfa_verify_success(
|
||||
self,
|
||||
fast_api_app,
|
||||
fast_api_client,
|
||||
sample_user_read,
|
||||
client_type,
|
||||
expected_status,
|
||||
returns_tokens,
|
||||
):
|
||||
"""
|
||||
Test successful MFA verification and login completion.
|
||||
|
||||
This test verifies that when a valid MFA code is provided for a pending login,
|
||||
the API successfully completes the login process and returns appropriate tokens
|
||||
based on the client type.
|
||||
|
||||
Args:
|
||||
fast_api_app: The FastAPI application instance under test.
|
||||
fast_api_client: The test client for making HTTP requests.
|
||||
sample_user_read: A sample user object.
|
||||
client_type: The type of client ("web" or "mobile").
|
||||
expected_status: The expected HTTP status code.
|
||||
returns_tokens: Boolean indicating if tokens should be returned.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
|
||||
# Setup pending MFA login
|
||||
pending_store = fast_api_app.state.fake_store
|
||||
pending_store._store = {"testuser": sample_user_read.id}
|
||||
|
||||
with patch(
|
||||
"session.router.profile_utils.verify_user_mfa"
|
||||
) as mock_verify_mfa, patch(
|
||||
"session.router.users_crud.get_user_by_id"
|
||||
) as mock_get_user, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.auth_utils.complete_login"
|
||||
) as mock_complete:
|
||||
mock_verify_mfa.return_value = True
|
||||
mock_get_user.return_value = sample_user_read
|
||||
mock_complete.return_value = (
|
||||
{"session_id": "test-session"}
|
||||
if not returns_tokens
|
||||
else {
|
||||
"access_token": "token",
|
||||
"refresh_token": "refresh",
|
||||
"session_id": "session",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 900,
|
||||
}
|
||||
)
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/mfa/verify",
|
||||
json={"username": "testuser", "mfa_code": "123456"},
|
||||
headers={"X-Client-Type": client_type},
|
||||
)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
body = resp.json()
|
||||
if returns_tokens:
|
||||
assert body["access_token"] == "token"
|
||||
assert body["refresh_token"] == "refresh"
|
||||
assert body["session_id"] == "session"
|
||||
else:
|
||||
assert body["session_id"] == "test-session"
|
||||
|
||||
def test_mfa_verify_no_pending_login(self, fast_api_app, fast_api_client):
|
||||
"""
|
||||
Test MFA verification when no pending login is found.
|
||||
|
||||
This test verifies that when attempting to verify MFA without a pending login,
|
||||
the API returns a 400 Bad Request error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.fake_store._store = {}
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/mfa/verify",
|
||||
json={"username": "testuser", "mfa_code": "123456"},
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert "No pending MFA login" in resp.json()["detail"]
|
||||
|
||||
def test_mfa_verify_invalid_code(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test MFA verification with an invalid MFA code.
|
||||
|
||||
This test verifies that when an invalid MFA code is provided,
|
||||
the API returns a 401 Unauthorized error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.fake_store._store = {"testuser": sample_user_read.id}
|
||||
|
||||
with patch("session.router.profile_utils.verify_user_mfa") as mock_verify_mfa:
|
||||
mock_verify_mfa.return_value = False
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/mfa/verify",
|
||||
json={"username": "testuser", "mfa_code": "999999"},
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "Invalid MFA code" in resp.json()["detail"]
|
||||
|
||||
def test_mfa_verify_user_not_found(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test MFA verification when user is not found after verification.
|
||||
|
||||
This test verifies that when a user cannot be found in the database
|
||||
after MFA verification, the API returns a 404 Not Found error and
|
||||
cleans up the pending login.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.fake_store._store = {"testuser": sample_user_read.id}
|
||||
|
||||
with patch(
|
||||
"session.router.profile_utils.verify_user_mfa"
|
||||
) as mock_verify_mfa, patch(
|
||||
"session.router.users_crud.get_user_by_id"
|
||||
) as mock_get_user:
|
||||
mock_verify_mfa.return_value = True
|
||||
mock_get_user.return_value = None
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/mfa/verify",
|
||||
json={"username": "testuser", "mfa_code": "123456"},
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "User not found" in resp.json()["detail"]
|
||||
|
||||
def test_mfa_verify_inactive_user(
|
||||
self, fast_api_app, fast_api_client, sample_inactive_user
|
||||
):
|
||||
"""
|
||||
Test MFA verification with an inactive user.
|
||||
|
||||
This test verifies that when an inactive user attempts to complete MFA login,
|
||||
the API returns a 403 Forbidden error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.fake_store._store = {"inactive": sample_inactive_user.id}
|
||||
|
||||
with patch(
|
||||
"session.router.profile_utils.verify_user_mfa"
|
||||
) as mock_verify_mfa, patch(
|
||||
"session.router.users_crud.get_user_by_id"
|
||||
) as mock_get_user, patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
) as mock_check_active:
|
||||
mock_verify_mfa.return_value = True
|
||||
mock_get_user.return_value = sample_inactive_user
|
||||
mock_check_active.side_effect = HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User is inactive"
|
||||
)
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/mfa/verify",
|
||||
json={"username": "inactive", "mfa_code": "123456"},
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "User is inactive" in resp.json()["detail"]
|
||||
|
||||
|
||||
class TestRefreshTokenEndpoint:
|
||||
"""
|
||||
Test suite for the refresh token endpoint (/auth/refresh).
|
||||
|
||||
This class contains tests that cover various scenarios for the token refresh endpoint, including:
|
||||
- Successful token refresh for different client types (web and mobile).
|
||||
- Handling of session not found errors.
|
||||
- Handling of invalid refresh token hash mismatches.
|
||||
- Handling of inactive users during refresh.
|
||||
- Handling of invalid client types during refresh.
|
||||
|
||||
Each test uses mocking to simulate token validation, session retrieval, and token creation.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_type, expected_status, returns_tokens",
|
||||
[
|
||||
("web", status.HTTP_200_OK, False),
|
||||
("mobile", status.HTTP_200_OK, True),
|
||||
],
|
||||
)
|
||||
def test_refresh_token_success(
|
||||
self,
|
||||
fast_api_app,
|
||||
fast_api_client,
|
||||
sample_user_read,
|
||||
password_hasher,
|
||||
client_type,
|
||||
expected_status,
|
||||
returns_tokens,
|
||||
):
|
||||
"""
|
||||
Test successful token refresh.
|
||||
|
||||
This test verifies that when a valid refresh token is provided,
|
||||
the API successfully creates new tokens and returns them based on client type.
|
||||
|
||||
Args:
|
||||
fast_api_app: The FastAPI application instance under test.
|
||||
fast_api_client: The test client for making HTTP requests.
|
||||
sample_user_read: A sample user object.
|
||||
password_hasher: The password hasher instance.
|
||||
mock_db: Mock database session.
|
||||
client_type: The type of client ("web" or "mobile").
|
||||
expected_status: The expected HTTP status code.
|
||||
returns_tokens: Boolean indicating if tokens should be returned.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
fast_api_app.state.mock_user_id = sample_user_read.id
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
), patch(
|
||||
"session.router.users_crud.get_user_by_id", return_value=sample_user_read
|
||||
), patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.auth_utils.create_tokens"
|
||||
) as mock_create_tokens, patch(
|
||||
"session.router.session_utils.edit_session"
|
||||
), patch(
|
||||
"session.router.auth_utils.create_response_with_tokens",
|
||||
side_effect=lambda r, a, rf, c: r,
|
||||
):
|
||||
# Set up proper mock for create_tokens with timestamp
|
||||
mock_access_exp = MagicMock()
|
||||
mock_access_exp.timestamp.return_value = 1234567890
|
||||
mock_refresh_exp = MagicMock()
|
||||
mock_create_tokens.return_value = (
|
||||
"new-session-id",
|
||||
mock_access_exp,
|
||||
"new_access_token",
|
||||
mock_refresh_exp,
|
||||
"new_refresh_token",
|
||||
"new_csrf_token",
|
||||
)
|
||||
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/refresh",
|
||||
headers={"X-Client-Type": client_type},
|
||||
)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
body = resp.json()
|
||||
if returns_tokens:
|
||||
assert body["access_token"] == "new_access_token"
|
||||
assert body["refresh_token"] == "new_refresh_token"
|
||||
assert body["session_id"] == "new-session-id"
|
||||
assert body["token_type"] == "bearer"
|
||||
else:
|
||||
assert body["session_id"] == "new-session-id"
|
||||
|
||||
def test_refresh_token_session_not_found(self, fast_api_app, fast_api_client):
|
||||
"""
|
||||
Test token refresh when session is not found.
|
||||
|
||||
This test verifies that when attempting to refresh with a session ID
|
||||
that doesn't exist, the API returns a 404 Not Found error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_session_id = "nonexistent-session"
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
with patch("session.router.session_crud.get_session_by_id", return_value=None):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/refresh",
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Session not found" in resp.json()["detail"]
|
||||
|
||||
def test_refresh_token_invalid_hash(
|
||||
self, fast_api_app, fast_api_client, password_hasher
|
||||
):
|
||||
"""
|
||||
Test token refresh with invalid refresh token hash.
|
||||
|
||||
This test verifies that when the refresh token hash doesn't match
|
||||
the stored hash, the API returns a 401 Unauthorized error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_refresh_token = "wrong_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password("different_token")
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "wrong_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/refresh",
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "Invalid refresh token" in resp.json()["detail"]
|
||||
|
||||
def test_refresh_token_inactive_user(
|
||||
self, fast_api_app, fast_api_client, sample_inactive_user, password_hasher
|
||||
):
|
||||
"""
|
||||
Test token refresh with an inactive user.
|
||||
|
||||
This test verifies that when an inactive user attempts to refresh tokens,
|
||||
the API returns a 403 Forbidden error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.mock_user_id = sample_inactive_user.id
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
), patch(
|
||||
"session.router.users_crud.get_user_by_id",
|
||||
return_value=sample_inactive_user,
|
||||
), patch(
|
||||
"session.router.users_utils.check_user_is_active",
|
||||
side_effect=HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User is inactive"
|
||||
),
|
||||
):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/refresh",
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_refresh_token_invalid_client_type(
|
||||
self, fast_api_app, fast_api_client, sample_user_read, password_hasher
|
||||
):
|
||||
"""
|
||||
Test token refresh with an invalid client type.
|
||||
|
||||
This test verifies that when an invalid client type is provided,
|
||||
the API returns a 403 Forbidden error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "desktop"
|
||||
fast_api_app.state.mock_user_id = sample_user_read.id
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
), patch(
|
||||
"session.router.users_crud.get_user_by_id", return_value=sample_user_read
|
||||
), patch(
|
||||
"session.router.users_utils.check_user_is_active"
|
||||
), patch(
|
||||
"session.router.auth_utils.create_tokens"
|
||||
) as mock_create_tokens, patch(
|
||||
"session.router.session_utils.edit_session"
|
||||
):
|
||||
# Set up proper mock for create_tokens with timestamp
|
||||
mock_access_exp = MagicMock()
|
||||
mock_access_exp.timestamp.return_value = 1234567890
|
||||
mock_refresh_exp = MagicMock()
|
||||
mock_create_tokens.return_value = (
|
||||
"new-session-id",
|
||||
mock_access_exp,
|
||||
"new_access_token",
|
||||
mock_refresh_exp,
|
||||
"new_refresh_token",
|
||||
"new_csrf_token",
|
||||
)
|
||||
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/refresh",
|
||||
headers={"X-Client-Type": "desktop"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "Invalid client type" in resp.json()["detail"]
|
||||
|
||||
|
||||
class TestLogoutEndpoint:
|
||||
"""
|
||||
Test suite for the logout endpoint (/auth/logout).
|
||||
|
||||
This class contains tests that cover various scenarios for the logout endpoint, including:
|
||||
- Successful logout for different client types (web and mobile).
|
||||
- Cookie clearing for web clients.
|
||||
- Handling of invalid refresh tokens during logout.
|
||||
- Handling of session not found during logout (should still succeed).
|
||||
- Handling of invalid client types during logout.
|
||||
|
||||
Each test uses mocking to simulate token validation, session retrieval, and session deletion.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"client_type, expected_status",
|
||||
[
|
||||
("web", status.HTTP_200_OK),
|
||||
("mobile", status.HTTP_200_OK),
|
||||
],
|
||||
)
|
||||
def test_logout_success(
|
||||
self,
|
||||
fast_api_app,
|
||||
fast_api_client,
|
||||
password_hasher,
|
||||
client_type,
|
||||
expected_status,
|
||||
):
|
||||
"""
|
||||
Test successful logout.
|
||||
|
||||
This test verifies that when a valid access and refresh token are provided,
|
||||
the API successfully deletes the session and returns a success message.
|
||||
|
||||
Args:
|
||||
fast_api_app: The FastAPI application instance under test.
|
||||
fast_api_client: The test client for making HTTP requests.
|
||||
password_hasher: The password hasher instance.
|
||||
client_type: The type of client ("web" or "mobile").
|
||||
expected_status: The expected HTTP status code.
|
||||
"""
|
||||
fast_api_app.state._client_type = client_type
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password(
|
||||
"refresh_token_value"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
), patch("session.router.session_crud.delete_session") as mock_delete:
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/logout",
|
||||
headers={"X-Client-Type": client_type},
|
||||
)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
assert resp.json()["message"] == "Logout successful"
|
||||
mock_delete.assert_called_once()
|
||||
|
||||
# Check cookies are cleared for web clients
|
||||
if client_type == "web":
|
||||
# The response should have set-cookie headers to clear cookies
|
||||
assert "set-cookie" in resp.headers or resp.cookies
|
||||
|
||||
def test_logout_invalid_refresh_token(
|
||||
self, fast_api_app, fast_api_client, password_hasher
|
||||
):
|
||||
"""
|
||||
Test logout with an invalid refresh token.
|
||||
|
||||
This test verifies that when the refresh token hash doesn't match,
|
||||
the API returns a 401 Unauthorized error.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_refresh_token = "wrong_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
mock_session.refresh_token = password_hasher.hash_password("different_token")
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "wrong_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/logout",
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "Invalid refresh token" in resp.json()["detail"]
|
||||
|
||||
def test_logout_session_not_found_still_succeeds(
|
||||
self, fast_api_app, fast_api_client
|
||||
):
|
||||
"""
|
||||
Test logout when session is not found (should still succeed).
|
||||
|
||||
This test verifies that when attempting to logout with a session ID
|
||||
that doesn't exist, the API still returns success (idempotent operation).
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
fast_api_app.state.mock_session_id = "nonexistent-session"
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
with patch("session.router.session_crud.get_session_by_id", return_value=None):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/logout",
|
||||
headers={"X-Client-Type": "web"},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_200_OK
|
||||
assert resp.json()["message"] == "Logout successful"
|
||||
|
||||
def test_logout_invalid_client_type(self, fast_api_app, fast_api_client):
|
||||
"""
|
||||
Test logout with an invalid client type.
|
||||
|
||||
This test verifies that when an invalid client type is provided,
|
||||
the API returns a 401 Unauthorized error (client type validation
|
||||
happens after authentication in the dependency chain).
|
||||
"""
|
||||
fast_api_app.state._client_type = "desktop"
|
||||
fast_api_app.state.mock_session_id = "test-session-id"
|
||||
fast_api_app.state.mock_user_id = 1
|
||||
fast_api_app.state.mock_refresh_token = "refresh_token_value"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.id = "test-session-id"
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_session_by_id", return_value=mock_session
|
||||
):
|
||||
# Set cookies on client instance (new API)
|
||||
fast_api_client.cookies.set("endurain_access_token", "access_token")
|
||||
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
|
||||
|
||||
resp = fast_api_client.post(
|
||||
"/auth/logout",
|
||||
headers={"X-Client-Type": "desktop"},
|
||||
)
|
||||
|
||||
# Client type validation happens in the header_client_type_scheme dependency
|
||||
# which runs after authentication, so we get 401 due to invalid client type
|
||||
# being rejected by the scheme validator
|
||||
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestSessionsEndpoints:
|
||||
"""
|
||||
Test suite for the sessions management endpoints.
|
||||
|
||||
This class contains tests for:
|
||||
- GET /sessions/user/{user_id} - Retrieve all sessions for a user
|
||||
- DELETE /sessions/{session_id}/user/{user_id} - Delete a specific session
|
||||
|
||||
Each test uses mocking to simulate authentication, authorization, and database operations.
|
||||
"""
|
||||
|
||||
def test_read_sessions_user_success(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of user sessions.
|
||||
|
||||
This test verifies that when a valid access token is provided with
|
||||
appropriate scopes, the API returns all sessions for the user.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
|
||||
mock_sessions = [
|
||||
{
|
||||
"id": "session-1",
|
||||
"user_id": sample_user_read.id,
|
||||
"device_type": "desktop",
|
||||
"browser": "Chrome",
|
||||
},
|
||||
{
|
||||
"id": "session-2",
|
||||
"user_id": sample_user_read.id,
|
||||
"device_type": "mobile",
|
||||
"browser": "Safari",
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.get_user_sessions", return_value=mock_sessions
|
||||
) as mock_get_sessions:
|
||||
resp = fast_api_client.get(
|
||||
f"/sessions/user/{sample_user_read.id}",
|
||||
headers={
|
||||
"X-Client-Type": "web",
|
||||
"Authorization": "Bearer access_token",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_200_OK
|
||||
assert len(resp.json()) == 2
|
||||
assert resp.json()[0]["id"] == "session-1"
|
||||
assert resp.json()[1]["id"] == "session-2"
|
||||
mock_get_sessions.assert_called_once()
|
||||
|
||||
def test_read_sessions_user_empty_list(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test retrieval of user sessions when no sessions exist.
|
||||
|
||||
This test verifies that when a user has no active sessions,
|
||||
the API returns an empty list.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
|
||||
with patch("session.router.session_crud.get_user_sessions", return_value=[]):
|
||||
resp = fast_api_client.get(
|
||||
f"/sessions/user/{sample_user_read.id}",
|
||||
headers={
|
||||
"X-Client-Type": "web",
|
||||
"Authorization": "Bearer access_token",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_200_OK
|
||||
assert resp.json() == []
|
||||
|
||||
def test_delete_session_success(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test successful deletion of a user session.
|
||||
|
||||
This test verifies that when a valid access token is provided with
|
||||
appropriate scopes, the API successfully deletes the specified session.
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
session_id = "session-to-delete"
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.delete_session", return_value=True
|
||||
) as mock_delete:
|
||||
resp = fast_api_client.delete(
|
||||
f"/sessions/{session_id}/user/{sample_user_read.id}",
|
||||
headers={
|
||||
"X-Client-Type": "web",
|
||||
"Authorization": "Bearer access_token",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_200_OK
|
||||
# Verify delete_session was called with the correct session_id and user_id
|
||||
# (the third argument is the database session which we don't need to verify)
|
||||
assert mock_delete.called
|
||||
call_args = mock_delete.call_args[0]
|
||||
assert call_args[0] == session_id
|
||||
assert call_args[1] == sample_user_read.id
|
||||
|
||||
def test_delete_session_not_found(
|
||||
self, fast_api_app, fast_api_client, sample_user_read
|
||||
):
|
||||
"""
|
||||
Test deletion of a non-existent session.
|
||||
|
||||
This test verifies that when attempting to delete a session that doesn't exist,
|
||||
the API handles it appropriately (implementation-dependent behavior).
|
||||
"""
|
||||
fast_api_app.state._client_type = "web"
|
||||
session_id = "nonexistent-session"
|
||||
|
||||
with patch(
|
||||
"session.router.session_crud.delete_session",
|
||||
side_effect=HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
|
||||
),
|
||||
):
|
||||
resp = fast_api_client.delete(
|
||||
f"/sessions/{session_id}/user/{sample_user_read.id}",
|
||||
headers={
|
||||
"X-Client-Type": "web",
|
||||
"Authorization": "Bearer access_token",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert "Session not found" in resp.json()["detail"]
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user