Add tests for identity_providers module and update .gitignore

Added comprehensive unit tests for the identity_providers module, including CRUD operations, schema validation, and utility functions. Updated .gitignore to exclude deeper __pycache__ directories. Removed session test files and old __pycache__ files from the repository.
This commit is contained in:
João Vitória Silva
2025-12-19 11:25:57 +00:00
parent 435647d6c0
commit a2b32bc847
19 changed files with 3163 additions and 2137 deletions

3
.gitignore vendored
View File

@@ -15,6 +15,9 @@ backend/app/*/*/*/__pycache__/
backend/app/*.pyc
backend/tests/__pycache__/
backend/tests/*/__pycache__/
backend/tests/*/*/__pycache__/
backend/tests/*/*/*/__pycache__/
backend/tests/*/*/*/*/__pycache__/
# Tests
backend/.coverage

View File

@@ -0,0 +1 @@
"""Tests for identity_providers module."""

View File

@@ -0,0 +1,677 @@
"""Tests for identity_providers.crud module."""
import pytest
from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from auth.identity_providers import crud as idp_crud
from auth.identity_providers.schema import (
IdentityProviderCreate,
IdentityProviderUpdate,
)
from auth.identity_providers.models import IdentityProvider
class TestGetIdentityProvider:
"""Test suite for get_identity_provider function."""
def test_get_identity_provider_success(self, mock_db):
"""Test successfully retrieving an identity provider by ID.
Args:
mock_db: Mocked database session
Asserts:
- Correct database query is executed
- Identity provider is returned
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.id = 1
mock_idp.name = "Test Provider"
mock_db.query.return_value.filter.return_value.first.return_value = mock_idp
# Act
result = idp_crud.get_identity_provider(1, mock_db)
# Assert
assert result == mock_idp
mock_db.query.assert_called_once()
mock_db.query.return_value.filter.assert_called_once()
def test_get_identity_provider_not_found(self, mock_db):
"""Test retrieving non-existent identity provider.
Args:
mock_db: Mocked database session
Asserts:
- None is returned when provider not found
"""
# Arrange
mock_db.query.return_value.filter.return_value.first.return_value = None
# Act
result = idp_crud.get_identity_provider(999, mock_db)
# Assert
assert result is None
def test_get_identity_provider_database_error(self, mock_db):
"""Test get_identity_provider handles database errors.
Args:
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised on database error
"""
# Arrange
mock_db.query.side_effect = Exception("Database error")
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.get_identity_provider(1, mock_db)
assert exc_info.value.status_code == 500
assert "Internal Server Error" in str(exc_info.value.detail)
class TestGetIdentityProviderBySlug:
"""Test suite for get_identity_provider_by_slug function."""
def test_get_identity_provider_by_slug_success(self, mock_db):
"""Test successfully retrieving an identity provider by slug.
Args:
mock_db: Mocked database session
Asserts:
- Correct database query is executed
- Identity provider is returned
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.slug = "test-provider"
mock_db.query.return_value.filter.return_value.first.return_value = mock_idp
# Act
result = idp_crud.get_identity_provider_by_slug("test-provider", mock_db)
# Assert
assert result == mock_idp
mock_db.query.assert_called_once()
def test_get_identity_provider_by_slug_not_found(self, mock_db):
"""Test retrieving non-existent identity provider by slug.
Args:
mock_db: Mocked database session
Asserts:
- None is returned when provider not found
"""
# Arrange
mock_db.query.return_value.filter.return_value.first.return_value = None
# Act
result = idp_crud.get_identity_provider_by_slug("nonexistent", mock_db)
# Assert
assert result is None
def test_get_identity_provider_by_slug_database_error(self, mock_db):
"""Test get_identity_provider_by_slug handles database errors.
Args:
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised on database error
"""
# Arrange
mock_db.query.side_effect = Exception("Database error")
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.get_identity_provider_by_slug("test", mock_db)
assert exc_info.value.status_code == 500
class TestGetAllIdentityProviders:
"""Test suite for get_all_identity_providers function."""
def test_get_all_identity_providers_success(self, mock_db):
"""Test successfully retrieving all identity providers.
Args:
mock_db: Mocked database session
Asserts:
- All providers are returned ordered by name
"""
# Arrange
mock_idps = [MagicMock(spec=IdentityProvider) for _ in range(3)]
mock_db.query.return_value.order_by.return_value.all.return_value = mock_idps
# Act
result = idp_crud.get_all_identity_providers(mock_db)
# Assert
assert result == mock_idps
assert len(result) == 3
mock_db.query.assert_called_once()
mock_db.query.return_value.order_by.assert_called_once()
def test_get_all_identity_providers_empty(self, mock_db):
"""Test retrieving all identity providers when none exist.
Args:
mock_db: Mocked database session
Asserts:
- Empty list is returned
"""
# Arrange
mock_db.query.return_value.order_by.return_value.all.return_value = []
# Act
result = idp_crud.get_all_identity_providers(mock_db)
# Assert
assert result == []
def test_get_all_identity_providers_database_error(self, mock_db):
"""Test get_all_identity_providers handles database errors.
Args:
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised on database error
"""
# Arrange
mock_db.query.side_effect = Exception("Database error")
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.get_all_identity_providers(mock_db)
assert exc_info.value.status_code == 500
class TestGetEnabledProviders:
"""Test suite for get_enabled_providers function."""
def test_get_enabled_providers_success(self, mock_db):
"""Test successfully retrieving enabled identity providers.
Args:
mock_db: Mocked database session
Asserts:
- Only enabled providers are returned
- Results are ordered by name
"""
# Arrange
mock_idps = [MagicMock(spec=IdentityProvider) for _ in range(2)]
(
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value
) = mock_idps
# Act
result = idp_crud.get_enabled_providers(mock_db)
# Assert
assert result == mock_idps
assert len(result) == 2
mock_db.query.assert_called_once()
mock_db.query.return_value.filter.assert_called_once()
def test_get_enabled_providers_empty(self, mock_db):
"""Test retrieving enabled providers when none exist.
Args:
mock_db: Mocked database session
Asserts:
- Empty list is returned
"""
# Arrange
(
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value
) = []
# Act
result = idp_crud.get_enabled_providers(mock_db)
# Assert
assert result == []
def test_get_enabled_providers_database_error(self, mock_db):
"""Test get_enabled_providers handles database errors.
Args:
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised on database error
"""
# Arrange
mock_db.query.side_effect = Exception("Database error")
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.get_enabled_providers(mock_db)
assert exc_info.value.status_code == 500
class TestCreateIdentityProvider:
"""Test suite for create_identity_provider function."""
@patch("auth.identity_providers.crud.idp_models.IdentityProvider")
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
def test_create_identity_provider_success(
self, mock_get_by_slug, mock_encrypt, mock_idp_model, mock_db
):
"""Test successfully creating a new identity provider.
Args:
mock_get_by_slug: Mocked get_identity_provider_by_slug function
mock_encrypt: Mocked encryption function
mock_idp_model: Mocked IdentityProvider model
mock_db: Mocked database session
Asserts:
- Identity provider is created with encrypted credentials
- Database commit and refresh are called
"""
# Arrange
mock_get_by_slug.return_value = None # No existing provider
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
mock_idp_instance = MagicMock()
mock_idp_instance.id = 1
mock_idp_instance.name = "Test Provider"
mock_idp_model.return_value = mock_idp_instance
idp_data = IdentityProviderCreate(
name="Test Provider",
slug="test-provider",
client_id="test-client-id",
client_secret="test-secret",
)
# Act
result = idp_crud.create_identity_provider(idp_data, mock_db)
# Assert
mock_get_by_slug.assert_called_once_with("test-provider", mock_db)
assert mock_encrypt.call_count == 2 # Called for client_id and client_secret
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
mock_db.refresh.assert_called_once()
assert result == mock_idp_instance
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
def test_create_identity_provider_slug_exists(self, mock_get_by_slug, mock_db):
"""Test creating identity provider with existing slug.
Args:
mock_get_by_slug: Mocked get_identity_provider_by_slug function
mock_db: Mocked database session
Asserts:
- HTTPException with 409 status is raised
- Error message indicates slug conflict
"""
# Arrange
mock_existing = MagicMock(spec=IdentityProvider)
mock_get_by_slug.return_value = mock_existing
idp_data = IdentityProviderCreate(
name="Test Provider",
slug="existing-slug",
client_id="test-client-id",
client_secret="test-secret",
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.create_identity_provider(idp_data, mock_db)
assert exc_info.value.status_code == 409
assert "already exists" in str(exc_info.value.detail)
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
def test_create_identity_provider_database_error(
self, mock_get_by_slug, mock_encrypt, mock_db
):
"""Test create_identity_provider handles database errors.
Args:
mock_get_by_slug: Mocked get_identity_provider_by_slug function
mock_encrypt: Mocked encryption function
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised
- Database rollback is called
"""
# Arrange
mock_get_by_slug.return_value = None
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
mock_db.commit.side_effect = Exception("Database error")
idp_data = IdentityProviderCreate(
name="Test",
slug="test",
client_id="client-id",
client_secret="secret",
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.create_identity_provider(idp_data, mock_db)
assert exc_info.value.status_code == 500
mock_db.rollback.assert_called_once()
class TestUpdateIdentityProvider:
"""Test suite for update_identity_provider function."""
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
@patch("auth.identity_providers.crud.get_identity_provider")
def test_update_identity_provider_success(
self, mock_get, mock_get_by_slug, mock_encrypt, mock_db
):
"""Test successfully updating an identity provider.
Args:
mock_get: Mocked get_identity_provider function
mock_get_by_slug: Mocked get_identity_provider_by_slug function
mock_encrypt: Mocked encryption function
mock_db: Mocked database session
Asserts:
- Identity provider is updated
- Encrypted fields are encrypted
- Database commit and refresh are called
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.id = 1
mock_idp.slug = "original-slug"
mock_get.return_value = mock_idp
mock_get_by_slug.return_value = None
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
idp_data = IdentityProviderUpdate(
name="Updated Provider",
slug="updated-slug",
client_id="new-client-id",
client_secret="new-secret",
)
# Act
result = idp_crud.update_identity_provider(1, idp_data, mock_db)
# Assert
mock_get.assert_called_once_with(1, mock_db)
mock_get_by_slug.assert_called_once_with("updated-slug", mock_db)
assert mock_encrypt.call_count == 2
mock_db.commit.assert_called_once()
mock_db.refresh.assert_called_once()
@patch("auth.identity_providers.crud.get_identity_provider")
def test_update_identity_provider_not_found(self, mock_get, mock_db):
"""Test updating non-existent identity provider.
Args:
mock_get: Mocked get_identity_provider function
mock_db: Mocked database session
Asserts:
- HTTPException with 404 status is raised
"""
# Arrange
mock_get.return_value = None
idp_data = IdentityProviderUpdate(
name="Updated", slug="updated-slug", client_id="client-123"
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.update_identity_provider(999, idp_data, mock_db)
assert exc_info.value.status_code == 404
assert "not found" in str(exc_info.value.detail)
@patch("auth.identity_providers.crud.get_identity_provider_by_slug")
@patch("auth.identity_providers.crud.get_identity_provider")
def test_update_identity_provider_slug_conflict(
self, mock_get, mock_get_by_slug, mock_db
):
"""Test updating identity provider with conflicting slug.
Args:
mock_get: Mocked get_identity_provider function
mock_get_by_slug: Mocked get_identity_provider_by_slug function
mock_db: Mocked database session
Asserts:
- HTTPException with 409 status is raised
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.slug = "original-slug"
mock_get.return_value = mock_idp
mock_existing = MagicMock(spec=IdentityProvider)
mock_existing.slug = "existing-slug"
mock_get_by_slug.return_value = mock_existing
idp_data = IdentityProviderUpdate(
name="Test", slug="existing-slug", client_id="client-123"
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.update_identity_provider(1, idp_data, mock_db)
assert exc_info.value.status_code == 409
assert "already exists" in str(exc_info.value.detail)
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
@patch("auth.identity_providers.crud.get_identity_provider")
def test_update_identity_provider_without_slug_change(
self, mock_get, mock_encrypt, mock_db
):
"""Test updating identity provider without changing slug.
Args:
mock_get: Mocked get_identity_provider function
mock_encrypt: Mocked encryption function
mock_db: Mocked database session
Asserts:
- get_identity_provider_by_slug is not called
- Update succeeds
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.slug = "test-slug"
mock_get.return_value = mock_idp
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
idp_data = IdentityProviderUpdate(
name="Updated Name", slug="test-slug", client_id="client-123"
)
# Act
result = idp_crud.update_identity_provider(1, idp_data, mock_db)
# Assert
mock_get.assert_called_once_with(1, mock_db)
mock_db.commit.assert_called_once()
@patch("auth.identity_providers.crud.core_cryptography.encrypt_token_fernet")
@patch("auth.identity_providers.crud.get_identity_provider")
def test_update_identity_provider_database_error(
self, mock_get, mock_encrypt, mock_db
):
"""Test update_identity_provider handles database errors.
Args:
mock_get: Mocked get_identity_provider function
mock_encrypt: Mocked encryption function
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised
- Database rollback is called
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.slug = "test-slug"
mock_get.return_value = mock_idp
mock_encrypt.side_effect = lambda x: f"encrypted_{x}"
mock_db.commit.side_effect = Exception("Database error")
# Use same slug to avoid conflict check
idp_data = IdentityProviderUpdate(
name="Updated", slug="test-slug", client_id="client-123"
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.update_identity_provider(1, idp_data, mock_db)
assert exc_info.value.status_code == 500
mock_db.rollback.assert_called_once()
class TestDeleteIdentityProvider:
"""Test suite for delete_identity_provider function."""
@patch(
"auth.identity_providers.crud.user_identity_providers_crud.check_user_identity_providers_by_idp_id"
)
@patch("auth.identity_providers.crud.get_identity_provider")
def test_delete_identity_provider_success(
self, mock_get, mock_check_users, mock_db
):
"""Test successfully deleting an identity provider.
Args:
mock_get: Mocked get_identity_provider function
mock_check_users: Mocked check for user links
mock_db: Mocked database session
Asserts:
- Identity provider is deleted
- Database commit is called
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_idp.id = 1
mock_idp.name = "Test Provider"
mock_get.return_value = mock_idp
mock_check_users.return_value = None # No linked users
# Act
idp_crud.delete_identity_provider(1, mock_db)
# Assert
mock_get.assert_called_once_with(1, mock_db)
mock_check_users.assert_called_once_with(1, mock_db)
mock_db.delete.assert_called_once_with(mock_idp)
mock_db.commit.assert_called_once()
@patch("auth.identity_providers.crud.get_identity_provider")
def test_delete_identity_provider_not_found(self, mock_get, mock_db):
"""Test deleting non-existent identity provider.
Args:
mock_get: Mocked get_identity_provider function
mock_db: Mocked database session
Asserts:
- HTTPException with 404 status is raised
"""
# Arrange
mock_get.return_value = None
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.delete_identity_provider(999, mock_db)
assert exc_info.value.status_code == 404
assert "not found" in str(exc_info.value.detail)
@patch(
"auth.identity_providers.crud.user_identity_providers_crud.check_user_identity_providers_by_idp_id"
)
@patch("auth.identity_providers.crud.get_identity_provider")
def test_delete_identity_provider_with_linked_users(
self, mock_get, mock_check_users, mock_db
):
"""Test deleting identity provider with linked users.
Args:
mock_get: Mocked get_identity_provider function
mock_check_users: Mocked check for user links
mock_db: Mocked database session
Asserts:
- HTTPException with 409 status is raised
- Error message indicates linked users
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_get.return_value = mock_idp
mock_check_users.return_value = True # Has linked users
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.delete_identity_provider(1, mock_db)
assert exc_info.value.status_code == 409
assert "linked users" in str(exc_info.value.detail)
@patch(
"auth.identity_providers.crud.user_identity_providers_crud.check_user_identity_providers_by_idp_id"
)
@patch("auth.identity_providers.crud.get_identity_provider")
def test_delete_identity_provider_database_error(
self, mock_get, mock_check_users, mock_db
):
"""Test delete_identity_provider handles database errors.
Args:
mock_get: Mocked get_identity_provider function
mock_check_users: Mocked check for user links
mock_db: Mocked database session
Asserts:
- HTTPException with 500 status is raised
- Database rollback is called
"""
# Arrange
mock_idp = MagicMock(spec=IdentityProvider)
mock_get.return_value = mock_idp
mock_check_users.return_value = None
mock_db.delete.side_effect = Exception("Database error")
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
idp_crud.delete_identity_provider(1, mock_db)
assert exc_info.value.status_code == 500
mock_db.rollback.assert_called_once()

View File

@@ -0,0 +1,812 @@
"""Tests for identity_providers.schema module."""
import pytest
from pydantic import ValidationError
from datetime import datetime
from auth.identity_providers.schema import (
IdentityProviderBase,
IdentityProviderCreate,
IdentityProviderUpdate,
IdentityProvider,
IdentityProviderPublic,
IdentityProviderTemplate,
TokenExchangeRequest,
TokenExchangeResponse,
)
class TestIdentityProviderBase:
"""Test suite for IdentityProviderBase schema."""
def test_valid_identity_provider_base(self):
"""Test creating a valid IdentityProviderBase instance.
Asserts:
- Instance is created successfully
- All fields have correct values
- Default values are applied
"""
# Arrange & Act
idp = IdentityProviderBase(
name="Test Provider",
slug="test-provider",
provider_type="oidc",
enabled=True,
issuer_url="https://auth.example.com",
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
userinfo_endpoint="https://auth.example.com/userinfo",
jwks_uri="https://auth.example.com/jwks",
scopes="openid profile email",
icon="test-icon",
auto_create_users=True,
sync_user_info=True,
user_mapping={"username": ["preferred_username"], "email": ["email"]},
client_id="test-client-id",
)
# Assert
assert idp.name == "Test Provider"
assert idp.slug == "test-provider"
assert idp.provider_type == "oidc"
assert idp.enabled is True
assert idp.issuer_url == "https://auth.example.com"
assert idp.scopes == "openid profile email"
assert idp.auto_create_users is True
assert idp.sync_user_info is True
assert idp.client_id == "test-client-id"
def test_identity_provider_base_with_defaults(self):
"""Test IdentityProviderBase with default values.
Asserts:
- Default provider_type is 'oidc'
- Default enabled is False
- Default scopes is 'openid profile email'
- Default auto_create_users is True
- Default sync_user_info is True
"""
# Arrange & Act
idp = IdentityProviderBase(name="Test", slug="test", client_id="client-123")
# Assert
assert idp.provider_type == "oidc"
assert idp.enabled is False
assert idp.scopes == "openid profile email"
assert idp.auto_create_users is True
assert idp.sync_user_info is True
def test_slug_validation_lowercase_alphanumeric_hyphens(self):
"""Test slug validation accepts valid characters.
Asserts:
- Lowercase letters, numbers, and hyphens are accepted
"""
# Arrange & Act
idp = IdentityProviderBase(
name="Test", slug="test-provider-123", client_id="client-123"
)
# Assert
assert idp.slug == "test-provider-123"
def test_slug_validation_uppercase_rejected(self):
"""Test slug validation rejects uppercase letters.
Asserts:
- ValidationError is raised for uppercase characters
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderBase(
name="Test", slug="Test-Provider", client_id="client-123"
)
assert "Slug must contain only lowercase letters" in str(exc_info.value)
def test_slug_validation_special_characters_rejected(self):
"""Test slug validation rejects special characters.
Asserts:
- ValidationError is raised for special characters
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderBase(
name="Test", slug="test_provider", client_id="client-123"
)
assert "Slug must contain only lowercase letters" in str(exc_info.value)
def test_provider_type_validation_oidc(self):
"""Test provider_type validation accepts 'oidc'.
Asserts:
- 'oidc' is accepted as valid provider type
"""
# Arrange & Act
idp = IdentityProviderBase(
name="Test", slug="test", provider_type="oidc", client_id="client-123"
)
# Assert
assert idp.provider_type == "oidc"
def test_provider_type_validation_oauth2(self):
"""Test provider_type validation accepts 'oauth2'.
Asserts:
- 'oauth2' is accepted as valid provider type
"""
# Arrange & Act
idp = IdentityProviderBase(
name="Test", slug="test", provider_type="oauth2", client_id="client-123"
)
# Assert
assert idp.provider_type == "oauth2"
def test_provider_type_validation_saml(self):
"""Test provider_type validation accepts 'saml'.
Asserts:
- 'saml' is accepted as valid provider type
"""
# Arrange & Act
idp = IdentityProviderBase(
name="Test", slug="test", provider_type="saml", client_id="client-123"
)
# Assert
assert idp.provider_type == "saml"
def test_provider_type_validation_invalid(self):
"""Test provider_type validation rejects invalid types.
Asserts:
- ValidationError is raised for invalid provider type
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderBase(
name="Test",
slug="test",
provider_type="invalid",
client_id="client-123",
)
assert "Provider type must be one of" in str(exc_info.value)
def test_name_max_length_validation(self):
"""Test name field max length validation.
Asserts:
- ValidationError is raised when name exceeds 100 characters
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderBase(name="x" * 101, slug="test", client_id="client-123")
assert "String should have at most 100 characters" in str(exc_info.value)
def test_name_min_length_validation(self):
"""Test name field min length validation.
Asserts:
- ValidationError is raised when name is empty
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderBase(name="", slug="test", client_id="client-123")
assert "String should have at least 1 character" in str(exc_info.value)
def test_slug_max_length_validation(self):
"""Test slug field max length validation.
Asserts:
- ValidationError is raised when slug exceeds 50 characters
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderBase(name="Test", slug="x" * 51, client_id="client-123")
assert "String should have at most 50 characters" in str(exc_info.value)
def test_optional_fields_can_be_none(self):
"""Test optional fields can be None.
Asserts:
- issuer_url, authorization_endpoint, token_endpoint can be None
- userinfo_endpoint, jwks_uri, icon can be None
- user_mapping can be None
"""
# Arrange & Act
idp = IdentityProviderBase(
name="Test",
slug="test",
issuer_url=None,
authorization_endpoint=None,
token_endpoint=None,
userinfo_endpoint=None,
jwks_uri=None,
icon=None,
user_mapping=None,
client_id="client-123",
)
# Assert
assert idp.issuer_url is None
assert idp.authorization_endpoint is None
assert idp.token_endpoint is None
assert idp.userinfo_endpoint is None
assert idp.jwks_uri is None
assert idp.icon is None
assert idp.user_mapping is None
class TestIdentityProviderCreate:
"""Test suite for IdentityProviderCreate schema."""
def test_valid_identity_provider_create(self):
"""Test creating a valid IdentityProviderCreate instance.
Asserts:
- Instance is created successfully
- client_secret is required and set
"""
# Arrange & Act
idp = IdentityProviderCreate(
name="Test Provider",
slug="test-provider",
client_id="test-client-id",
client_secret="test-client-secret",
)
# Assert
assert idp.name == "Test Provider"
assert idp.slug == "test-provider"
assert idp.client_id == "test-client-id"
assert idp.client_secret == "test-client-secret"
def test_client_secret_required(self):
"""Test client_secret is required for IdentityProviderCreate.
Asserts:
- ValidationError is raised when client_secret is missing
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderCreate(name="Test", slug="test", client_id="client-123")
assert "client_secret" in str(exc_info.value)
assert "Field required" in str(exc_info.value)
def test_client_secret_min_length(self):
"""Test client_secret minimum length validation.
Asserts:
- ValidationError is raised when client_secret is empty
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderCreate(
name="Test",
slug="test",
client_id="client-123",
client_secret="",
)
assert "String should have at least 1 character" in str(exc_info.value)
def test_client_secret_max_length(self):
"""Test client_secret maximum length validation.
Asserts:
- ValidationError is raised when client_secret exceeds 512 characters
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderCreate(
name="Test",
slug="test",
client_id="client-123",
client_secret="x" * 513,
)
assert "String should have at most 512 characters" in str(exc_info.value)
class TestIdentityProviderUpdate:
"""Test suite for IdentityProviderUpdate schema."""
def test_valid_identity_provider_update(self):
"""Test creating a valid IdentityProviderUpdate instance.
Asserts:
- Instance is created successfully
- All fields can be updated
"""
# Arrange & Act
idp = IdentityProviderUpdate(
name="Updated Provider",
slug="updated-provider",
provider_type="oauth2",
enabled=True,
client_id="updated-client-id",
client_secret="updated-secret",
)
# Assert
assert idp.name == "Updated Provider"
assert idp.slug == "updated-provider"
assert idp.provider_type == "oauth2"
assert idp.enabled is True
assert idp.client_id == "updated-client-id"
assert idp.client_secret == "updated-secret"
def test_client_secret_optional(self):
"""Test client_secret is optional for IdentityProviderUpdate.
Asserts:
- Instance can be created without client_secret
- client_secret defaults to None
"""
# Arrange & Act
idp = IdentityProviderUpdate(name="Test", slug="test", client_id="client-123")
# Assert
assert idp.client_secret is None
def test_client_secret_can_be_none(self):
"""Test client_secret can explicitly be None.
Asserts:
- client_secret can be set to None
"""
# Arrange & Act
idp = IdentityProviderUpdate(
name="Test", slug="test", client_id="client-123", client_secret=None
)
# Assert
assert idp.client_secret is None
def test_client_secret_min_length_when_provided(self):
"""Test client_secret minimum length when provided.
Asserts:
- ValidationError is raised when client_secret is empty string
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderUpdate(
name="Test",
slug="test",
client_id="client-123",
client_secret="",
)
assert "String should have at least 1 character" in str(exc_info.value)
class TestIdentityProvider:
"""Test suite for IdentityProvider schema."""
def test_valid_identity_provider(self):
"""Test creating a valid IdentityProvider instance.
Asserts:
- Instance is created successfully
- All fields including id and timestamps are set
"""
# Arrange
now = datetime.utcnow()
# Act
idp = IdentityProvider(
id=1,
name="Test Provider",
slug="test-provider",
provider_type="oidc",
enabled=True,
client_id="test-client-id",
created_at=now,
updated_at=now,
)
# Assert
assert idp.id == 1
assert idp.name == "Test Provider"
assert idp.slug == "test-provider"
assert idp.provider_type == "oidc"
assert idp.enabled is True
assert idp.client_id == "test-client-id"
assert idp.created_at == now
assert idp.updated_at == now
def test_identity_provider_requires_id(self):
"""Test id field is required for IdentityProvider.
Asserts:
- ValidationError is raised when id is missing
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProvider(
name="Test",
slug="test",
client_id="client-123",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert "id" in str(exc_info.value)
assert "Field required" in str(exc_info.value)
def test_identity_provider_requires_timestamps(self):
"""Test created_at and updated_at are required.
Asserts:
- ValidationError is raised when timestamps are missing
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProvider(id=1, name="Test", slug="test", client_id="client-123")
assert "created_at" in str(exc_info.value) or "updated_at" in str(
exc_info.value
)
def test_serialize_client_id_decrypts_encrypted_value(self):
"""Test serialize_client_id decrypts encrypted client_id.
Asserts:
- Encrypted client_id (starting with 'gAAAAAB') triggers decryption
"""
# Arrange
now = datetime.utcnow()
encrypted_id = "gAAAAABtest_encrypted_token"
# Act
idp = IdentityProvider(
id=1,
name="Test",
slug="test",
client_id=encrypted_id,
created_at=now,
updated_at=now,
)
# Assert - This will be tested with actual encryption in integration tests
assert idp.client_id == encrypted_id
def test_serialize_client_id_preserves_unencrypted_value(self):
"""Test serialize_client_id preserves non-encrypted client_id.
Asserts:
- Non-encrypted client_id is returned as-is
"""
# Arrange
now = datetime.utcnow()
plain_id = "plain-client-id"
# Act
idp = IdentityProvider(
id=1,
name="Test",
slug="test",
client_id=plain_id,
created_at=now,
updated_at=now,
)
# Assert
assert idp.client_id == plain_id
def test_serialize_client_id_handles_none(self):
"""Test serialize_client_id handles None client_id.
Asserts:
- None client_id returns None
"""
# Arrange
now = datetime.utcnow()
# Act
idp = IdentityProvider(
id=1,
name="Test",
slug="test",
client_id=None,
created_at=now,
updated_at=now,
)
# Assert
assert idp.client_id is None
class TestIdentityProviderPublic:
"""Test suite for IdentityProviderPublic schema."""
def test_valid_identity_provider_public(self):
"""Test creating a valid IdentityProviderPublic instance.
Asserts:
- Instance is created successfully
- Only public fields are present
"""
# Arrange & Act
idp = IdentityProviderPublic(
id=1, name="Test Provider", slug="test-provider", icon="test-icon"
)
# Assert
assert idp.id == 1
assert idp.name == "Test Provider"
assert idp.slug == "test-provider"
assert idp.icon == "test-icon"
def test_identity_provider_public_icon_optional(self):
"""Test icon field is optional for IdentityProviderPublic.
Asserts:
- Instance can be created without icon
- icon defaults to None
"""
# Arrange & Act
idp = IdentityProviderPublic(id=1, name="Test", slug="test")
# Assert
assert idp.icon is None
def test_identity_provider_public_no_sensitive_fields(self):
"""Test IdentityProviderPublic doesn't have sensitive fields.
Asserts:
- client_id, client_secret are not present
"""
# Arrange & Act
idp = IdentityProviderPublic(id=1, name="Test", slug="test")
# Assert
assert not hasattr(idp, "client_id")
assert not hasattr(idp, "client_secret")
class TestIdentityProviderTemplate:
"""Test suite for IdentityProviderTemplate schema."""
def test_valid_identity_provider_template(self):
"""Test creating a valid IdentityProviderTemplate instance.
Asserts:
- Instance is created successfully
- All fields are set correctly
"""
# Arrange & Act
template = IdentityProviderTemplate(
template_id="keycloak",
name="Keycloak",
provider_type="oidc",
issuer_url="https://keycloak.example.com/realms/master",
scopes="openid profile email",
icon="keycloak",
user_mapping={"username": ["preferred_username"], "email": ["email"]},
description="Keycloak OIDC provider",
configuration_notes="Setup instructions here",
)
# Assert
assert template.template_id == "keycloak"
assert template.name == "Keycloak"
assert template.provider_type == "oidc"
assert template.issuer_url == "https://keycloak.example.com/realms/master"
assert template.scopes == "openid profile email"
assert template.icon == "keycloak"
assert template.user_mapping == {
"username": ["preferred_username"],
"email": ["email"],
}
assert template.description == "Keycloak OIDC provider"
assert template.configuration_notes == "Setup instructions here"
def test_template_required_fields(self):
"""Test required fields for IdentityProviderTemplate.
Asserts:
- ValidationError is raised when required fields are missing
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
IdentityProviderTemplate()
error_str = str(exc_info.value)
assert "template_id" in error_str or "Field required" in error_str
def test_template_optional_fields(self):
"""Test optional fields for IdentityProviderTemplate.
Asserts:
- Template can be created with only required fields
- Optional fields default to None
"""
# Arrange & Act
template = IdentityProviderTemplate(
template_id="custom",
name="Custom Provider",
provider_type="oidc",
scopes="openid",
description="Custom provider",
)
# Assert
assert template.issuer_url is None
assert template.icon is None
assert template.user_mapping is None
assert template.configuration_notes is None
class TestTokenExchangeRequest:
"""Test suite for TokenExchangeRequest schema."""
def test_valid_token_exchange_request(self):
"""Test creating a valid TokenExchangeRequest instance.
Asserts:
- Instance is created successfully with valid code_verifier
"""
# Arrange
# Valid base64url string (43 characters minimum)
code_verifier = "a" * 43 + "B1-_" * 10
# Act
request = TokenExchangeRequest(code_verifier=code_verifier)
# Assert
assert request.code_verifier == code_verifier
def test_code_verifier_min_length(self):
"""Test code_verifier minimum length validation.
Asserts:
- ValidationError is raised when code_verifier is less than 43 chars
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
TokenExchangeRequest(code_verifier="a" * 42)
assert "String should have at least 43 characters" in str(exc_info.value)
def test_code_verifier_max_length(self):
"""Test code_verifier maximum length validation.
Asserts:
- ValidationError is raised when code_verifier exceeds 128 chars
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
TokenExchangeRequest(code_verifier="a" * 129)
assert "String should have at most 128 characters" in str(exc_info.value)
def test_code_verifier_format_validation_valid(self):
"""Test code_verifier format validation accepts valid base64url.
Asserts:
- Valid base64url characters (A-Z, a-z, 0-9, -, _) are accepted
"""
# Arrange
valid_verifier = (
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
)
# Act
request = TokenExchangeRequest(code_verifier=valid_verifier)
# Assert
assert request.code_verifier == valid_verifier
def test_code_verifier_format_validation_invalid(self):
"""Test code_verifier format validation rejects invalid characters.
Asserts:
- ValidationError is raised for non-base64url characters
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
TokenExchangeRequest(
code_verifier="a" * 43 + "!"
) # ! is not valid base64url
assert "code_verifier must be valid base64url" in str(exc_info.value)
def test_code_verifier_format_validation_special_chars(self):
"""Test code_verifier format validation rejects special characters.
Asserts:
- ValidationError is raised for special characters like +, =
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
TokenExchangeRequest(code_verifier="a" * 43 + "+")
assert "code_verifier must be valid base64url" in str(exc_info.value)
class TestTokenExchangeResponse:
"""Test suite for TokenExchangeResponse schema."""
def test_valid_token_exchange_response(self):
"""Test creating a valid TokenExchangeResponse instance.
Asserts:
- Instance is created successfully with all required fields
- Default values are applied
"""
# Arrange & Act
response = TokenExchangeResponse(
session_id="session-123",
access_token="access-token-xyz",
refresh_token="refresh-token-abc",
csrf_token="csrf-token-def",
)
# Assert
assert response.session_id == "session-123"
assert response.access_token == "access-token-xyz"
assert response.refresh_token == "refresh-token-abc"
assert response.csrf_token == "csrf-token-def"
assert response.expires_in == 900 # Default value
assert response.token_type == "Bearer" # Default value
def test_token_exchange_response_required_fields(self):
"""Test required fields for TokenExchangeResponse.
Asserts:
- ValidationError is raised when required fields are missing
"""
# Arrange & Act & Assert
with pytest.raises(ValidationError) as exc_info:
TokenExchangeResponse()
error_str = str(exc_info.value)
assert "session_id" in error_str or "access_token" in error_str
def test_token_exchange_response_custom_expires_in(self):
"""Test custom expires_in value for TokenExchangeResponse.
Asserts:
- Custom expires_in value can be set
"""
# Arrange & Act
response = TokenExchangeResponse(
session_id="session-123",
access_token="access-token-xyz",
refresh_token="refresh-token-abc",
csrf_token="csrf-token-def",
expires_in=1800,
)
# Assert
assert response.expires_in == 1800
def test_token_exchange_response_custom_token_type(self):
"""Test custom token_type value for TokenExchangeResponse.
Asserts:
- Custom token_type value can be set
"""
# Arrange & Act
response = TokenExchangeResponse(
session_id="session-123",
access_token="access-token-xyz",
refresh_token="refresh-token-abc",
csrf_token="csrf-token-def",
token_type="Custom",
)
# Assert
assert response.token_type == "Custom"

View File

@@ -0,0 +1,493 @@
"""Tests for identity_providers.utils module."""
import pytest
from fastapi import HTTPException
from auth.identity_providers.utils import (
validate_pkce_challenge,
validate_pkce_verifier,
_secure_compare,
get_idp_template,
get_idp_templates,
)
from auth.identity_providers.schema import IdentityProviderTemplate
class TestValidatePkceChallenge:
"""Test suite for validate_pkce_challenge function."""
def test_validate_pkce_challenge_success(self):
"""Test validating a valid PKCE challenge.
Asserts:
- Valid S256 challenge passes validation
"""
# Arrange
# Valid base64url-encoded SHA256 hash (43 chars minimum)
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge_method = "S256"
# Act & Assert (no exception should be raised)
validate_pkce_challenge(code_challenge, code_challenge_method)
def test_validate_pkce_challenge_plain_method_rejected(self):
"""Test PKCE plain method is rejected.
Asserts:
- HTTPException with 400 status is raised for 'plain' method
"""
# Arrange
code_challenge = "test_challenge"
code_challenge_method = "plain"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_challenge(code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "Only S256 PKCE method is supported" in str(exc_info.value.detail)
def test_validate_pkce_challenge_min_length_violation(self):
"""Test PKCE challenge minimum length validation.
Asserts:
- HTTPException is raised for challenge shorter than 43 chars
"""
# Arrange
code_challenge = "a" * 42 # Too short
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_challenge(code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "43-128 characters" in str(exc_info.value.detail)
def test_validate_pkce_challenge_max_length_violation(self):
"""Test PKCE challenge maximum length validation.
Asserts:
- HTTPException is raised for challenge longer than 128 chars
"""
# Arrange
code_challenge = "a" * 129 # Too long
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_challenge(code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "43-128 characters" in str(exc_info.value.detail)
def test_validate_pkce_challenge_invalid_characters(self):
"""Test PKCE challenge with invalid characters.
Asserts:
- HTTPException is raised for non-base64url characters
"""
# Arrange
code_challenge = "a" * 43 + "!@#$" # Invalid characters
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_challenge(code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "valid base64url" in str(exc_info.value.detail)
def test_validate_pkce_challenge_valid_base64url_characters(self):
"""Test PKCE challenge with all valid base64url characters.
Asserts:
- All base64url characters (A-Z, a-z, 0-9, -, _) are accepted
"""
# Arrange
code_challenge = (
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr_-" # 45 chars, valid
)
code_challenge_method = "S256"
# Act & Assert (no exception should be raised)
validate_pkce_challenge(code_challenge, code_challenge_method)
class TestValidatePkceVerifier:
"""Test suite for validate_pkce_verifier function."""
def test_validate_pkce_verifier_success(self):
"""Test validating a valid PKCE verifier.
Asserts:
- Valid verifier that matches challenge passes validation
"""
# Arrange
# Valid verifier (RFC 7636 compliant)
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
# Corresponding S256 challenge
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge_method = "S256"
# Act & Assert (no exception should be raised)
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
def test_validate_pkce_verifier_min_length_violation(self):
"""Test PKCE verifier minimum length validation.
Asserts:
- HTTPException is raised for verifier shorter than 43 chars
"""
# Arrange
code_verifier = "a" * 42 # Too short
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "43-128 characters" in str(exc_info.value.detail)
def test_validate_pkce_verifier_max_length_violation(self):
"""Test PKCE verifier maximum length validation.
Asserts:
- HTTPException is raised for verifier longer than 128 chars
"""
# Arrange
code_verifier = "a" * 129 # Too long
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "43-128 characters" in str(exc_info.value.detail)
def test_validate_pkce_verifier_invalid_characters(self):
"""Test PKCE verifier with invalid characters.
Asserts:
- HTTPException is raised for non-base64url characters
"""
# Arrange
code_verifier = "a" * 43 + "!@#" # Invalid characters
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "valid base64url" in str(exc_info.value.detail)
def test_validate_pkce_verifier_wrong_method(self):
"""Test PKCE verifier with wrong challenge method.
Asserts:
- HTTPException is raised for non-S256 method
"""
# Arrange
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
code_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge_method = "plain"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "Only S256 PKCE method is supported" in str(exc_info.value.detail)
def test_validate_pkce_verifier_mismatch(self):
"""Test PKCE verifier that doesn't match challenge.
Asserts:
- HTTPException is raised for mismatched verifier
"""
# Arrange
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
code_challenge = "wrong_challenge_value_that_does_not_match_verifier"
code_challenge_method = "S256"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
validate_pkce_verifier(code_verifier, code_challenge, code_challenge_method)
assert exc_info.value.status_code == 400
assert "Invalid code_verifier" in str(exc_info.value.detail)
class TestSecureCompare:
"""Test suite for _secure_compare function."""
def test_secure_compare_equal_strings(self):
"""Test secure comparison of equal strings.
Asserts:
- Returns True for identical strings
"""
# Arrange
str1 = "test_string_123"
str2 = "test_string_123"
# Act
result = _secure_compare(str1, str2)
# Assert
assert result is True
def test_secure_compare_different_strings(self):
"""Test secure comparison of different strings.
Asserts:
- Returns False for different strings
"""
# Arrange
str1 = "test_string_123"
str2 = "test_string_456"
# Act
result = _secure_compare(str1, str2)
# Assert
assert result is False
def test_secure_compare_different_lengths(self):
"""Test secure comparison of strings with different lengths.
Asserts:
- Returns False for strings of different lengths
"""
# Arrange
str1 = "short"
str2 = "much_longer_string"
# Act
result = _secure_compare(str1, str2)
# Assert
assert result is False
def test_secure_compare_empty_strings(self):
"""Test secure comparison of empty strings.
Asserts:
- Returns True for both empty strings
"""
# Arrange
str1 = ""
str2 = ""
# Act
result = _secure_compare(str1, str2)
# Assert
assert result is True
def test_secure_compare_one_empty_string(self):
"""Test secure comparison with one empty string.
Asserts:
- Returns False when one string is empty
"""
# Arrange
str1 = "non_empty"
str2 = ""
# Act
result = _secure_compare(str1, str2)
# Assert
assert result is False
def test_secure_compare_case_sensitive(self):
"""Test secure comparison is case-sensitive.
Asserts:
- Returns False for strings differing only in case
"""
# Arrange
str1 = "TestString"
str2 = "teststring"
# Act
result = _secure_compare(str1, str2)
# Assert
assert result is False
class TestGetIdpTemplate:
"""Test suite for get_idp_template function."""
def test_get_idp_template_keycloak(self):
"""Test retrieving Keycloak IdP template.
Asserts:
- Keycloak template is returned with correct structure
"""
# Act
template = get_idp_template("keycloak")
# Assert
assert template is not None
assert template["name"] == "Keycloak"
assert template["provider_type"] == "oidc"
assert "issuer_url" in template
assert "scopes" in template
assert "user_mapping" in template
def test_get_idp_template_authentik(self):
"""Test retrieving Authentik IdP template.
Asserts:
- Authentik template is returned with correct structure
"""
# Act
template = get_idp_template("authentik")
# Assert
assert template is not None
assert template["name"] == "Authentik"
assert template["provider_type"] == "oidc"
def test_get_idp_template_authelia(self):
"""Test retrieving Authelia IdP template.
Asserts:
- Authelia template is returned with correct structure
"""
# Act
template = get_idp_template("authelia")
# Assert
assert template is not None
assert template["name"] == "Authelia"
assert template["provider_type"] == "oidc"
def test_get_idp_template_casdoor(self):
"""Test retrieving Casdoor IdP template.
Asserts:
- Casdoor template is returned with correct structure
"""
# Act
template = get_idp_template("casdoor")
# Assert
assert template is not None
assert template["name"] == "Casdoor"
assert template["provider_type"] == "oidc"
def test_get_idp_template_pocketid(self):
"""Test retrieving Pocket ID template.
Asserts:
- Pocket ID template is returned with correct structure
"""
# Act
template = get_idp_template("pocketid")
# Assert
assert template is not None
assert template["name"] == "Pocket ID"
assert template["provider_type"] == "oidc"
def test_get_idp_template_nonexistent(self):
"""Test retrieving non-existent IdP template.
Asserts:
- None is returned for non-existent template
"""
# Act
template = get_idp_template("nonexistent_provider")
# Assert
assert template is None
class TestGetIdpTemplates:
"""Test suite for get_idp_templates function."""
def test_get_idp_templates_returns_list(self):
"""Test get_idp_templates returns a list.
Asserts:
- Returns a list of IdentityProviderTemplate objects
"""
# Act
templates = get_idp_templates()
# Assert
assert isinstance(templates, list)
assert len(templates) > 0
assert all(isinstance(t, IdentityProviderTemplate) for t in templates)
def test_get_idp_templates_contains_expected_providers(self):
"""Test get_idp_templates contains expected providers.
Asserts:
- List contains Keycloak, Authentik, Authelia, Casdoor, Pocket ID
"""
# Act
templates = get_idp_templates()
# Assert
template_ids = [t.template_id for t in templates]
assert "keycloak" in template_ids
assert "authentik" in template_ids
assert "authelia" in template_ids
assert "casdoor" in template_ids
assert "pocketid" in template_ids
def test_get_idp_templates_structure(self):
"""Test each template has required structure.
Asserts:
- Each template has required fields
"""
# Act
templates = get_idp_templates()
# Assert
for template in templates:
assert hasattr(template, "template_id")
assert hasattr(template, "name")
assert hasattr(template, "provider_type")
assert hasattr(template, "scopes")
assert hasattr(template, "description")
def test_get_idp_templates_all_oidc(self):
"""Test all templates are OIDC providers.
Asserts:
- All templates have provider_type 'oidc'
"""
# Act
templates = get_idp_templates()
# Assert
for template in templates:
assert template.provider_type == "oidc"
def test_get_idp_templates_has_user_mapping(self):
"""Test templates include user mapping configuration.
Asserts:
- Each template has user_mapping with username and email
"""
# Act
templates = get_idp_templates()
# Assert
for template in templates:
assert template.user_mapping is not None
assert "username" in template.user_mapping
assert "email" in template.user_mapping

View File

@@ -0,0 +1,181 @@
"""Tests for auth.constants module."""
import pytest
import auth.constants as auth_constants
class TestJWTConstants:
"""Test JWT configuration constants."""
def test_jwt_algorithm_is_set(self):
"""Test that JWT algorithm is defined and is a valid algorithm."""
assert auth_constants.JWT_ALGORITHM is not None
assert isinstance(auth_constants.JWT_ALGORITHM, str)
assert auth_constants.JWT_ALGORITHM in ["HS256", "HS384", "HS512"]
def test_access_token_expiry_is_positive(self):
"""Test that access token expiry is a positive integer."""
assert auth_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES > 0
assert isinstance(auth_constants.JWT_ACCESS_TOKEN_EXPIRE_MINUTES, int)
def test_refresh_token_expiry_is_positive(self):
"""Test that refresh token expiry is a positive integer."""
assert auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS > 0
assert isinstance(auth_constants.JWT_REFRESH_TOKEN_EXPIRE_DAYS, int)
def test_secret_key_is_set(self):
"""Test that SECRET_KEY is configured."""
assert auth_constants.JWT_SECRET_KEY is not None
assert isinstance(auth_constants.JWT_SECRET_KEY, str)
assert len(auth_constants.JWT_SECRET_KEY) >= 32
def test_session_timeout_constants(self):
"""Test that session timeout constants are valid."""
assert isinstance(auth_constants.SESSION_IDLE_TIMEOUT_ENABLED, bool)
assert auth_constants.SESSION_IDLE_TIMEOUT_HOURS >= 0
assert isinstance(auth_constants.SESSION_IDLE_TIMEOUT_HOURS, int)
assert auth_constants.SESSION_ABSOLUTE_TIMEOUT_HOURS > 0
assert isinstance(auth_constants.SESSION_ABSOLUTE_TIMEOUT_HOURS, int)
class TestScopeConstants:
"""Test scope configuration constants."""
def test_users_regular_scope_defined(self):
"""Test that users regular scope is properly defined."""
assert auth_constants.USERS_REGULAR_SCOPE is not None
assert isinstance(auth_constants.USERS_REGULAR_SCOPE, tuple)
assert "profile" in auth_constants.USERS_REGULAR_SCOPE
assert "users:read" in auth_constants.USERS_REGULAR_SCOPE
def test_users_admin_scope_defined(self):
"""Test that users admin scope is properly defined."""
assert auth_constants.USERS_ADMIN_SCOPE is not None
assert isinstance(auth_constants.USERS_ADMIN_SCOPE, tuple)
assert "users:write" in auth_constants.USERS_ADMIN_SCOPE
assert "sessions:read" in auth_constants.USERS_ADMIN_SCOPE
assert "sessions:write" in auth_constants.USERS_ADMIN_SCOPE
def test_gears_scope_defined(self):
"""Test that gears scope is properly defined."""
assert auth_constants.GEARS_SCOPE is not None
assert isinstance(auth_constants.GEARS_SCOPE, tuple)
assert "gears:read" in auth_constants.GEARS_SCOPE
assert "gears:write" in auth_constants.GEARS_SCOPE
def test_activities_scope_defined(self):
"""Test that activities scope is properly defined."""
assert auth_constants.ACTIVITIES_SCOPE is not None
assert isinstance(auth_constants.ACTIVITIES_SCOPE, tuple)
assert "activities:read" in auth_constants.ACTIVITIES_SCOPE
assert "activities:write" in auth_constants.ACTIVITIES_SCOPE
def test_health_scope_defined(self):
"""Test that health scope is properly defined."""
assert auth_constants.HEALTH_SCOPE is not None
assert isinstance(auth_constants.HEALTH_SCOPE, tuple)
assert "health:read" in auth_constants.HEALTH_SCOPE
assert "health:write" in auth_constants.HEALTH_SCOPE
assert "health_targets:read" in auth_constants.HEALTH_SCOPE
assert "health_targets:write" in auth_constants.HEALTH_SCOPE
def test_identity_providers_scope_defined(self):
"""Test that identity providers scopes are properly defined."""
assert auth_constants.IDENTITY_PROVIDERS_REGULAR_SCOPE is not None
assert isinstance(auth_constants.IDENTITY_PROVIDERS_REGULAR_SCOPE, tuple)
assert (
"identity_providers:read" in auth_constants.IDENTITY_PROVIDERS_REGULAR_SCOPE
)
assert auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE is not None
assert isinstance(auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE, tuple)
assert (
"identity_providers:write" in auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE
)
def test_server_settings_scope_defined(self):
"""Test that server settings scopes are properly defined."""
assert auth_constants.SERVER_SETTINGS_REGULAR_SCOPE is not None
assert isinstance(auth_constants.SERVER_SETTINGS_REGULAR_SCOPE, tuple)
assert auth_constants.SERVER_SETTINGS_ADMIN_SCOPE is not None
assert isinstance(auth_constants.SERVER_SETTINGS_ADMIN_SCOPE, tuple)
assert "server_settings:read" in auth_constants.SERVER_SETTINGS_ADMIN_SCOPE
assert "server_settings:write" in auth_constants.SERVER_SETTINGS_ADMIN_SCOPE
def test_scope_dict_contains_all_scopes(self):
"""Test that SCOPE_DICT contains descriptions for all defined scopes."""
assert auth_constants.SCOPE_DICT is not None
assert isinstance(auth_constants.SCOPE_DICT, dict)
# Check key scopes are documented
expected_scopes = [
"profile",
"users:read",
"users:write",
"gears:read",
"gears:write",
"activities:read",
"activities:write",
"health:read",
"health:write",
"server_settings:read",
"server_settings:write",
]
for scope in expected_scopes:
assert scope in auth_constants.SCOPE_DICT
assert isinstance(auth_constants.SCOPE_DICT[scope], str)
assert len(auth_constants.SCOPE_DICT[scope]) > 0
def test_regular_access_scope_composition(self):
"""Test that REGULAR_ACCESS_SCOPE includes all expected regular scopes."""
assert auth_constants.REGULAR_ACCESS_SCOPE is not None
assert isinstance(auth_constants.REGULAR_ACCESS_SCOPE, tuple)
# Should include users regular scope
for scope in auth_constants.USERS_REGULAR_SCOPE:
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
# Should include gears scope
for scope in auth_constants.GEARS_SCOPE:
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
# Should include activities scope
for scope in auth_constants.ACTIVITIES_SCOPE:
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
# Should include health scope
for scope in auth_constants.HEALTH_SCOPE:
assert scope in auth_constants.REGULAR_ACCESS_SCOPE
def test_admin_access_scope_composition(self):
"""Test that ADMIN_ACCESS_SCOPE includes all regular and admin scopes."""
assert auth_constants.ADMIN_ACCESS_SCOPE is not None
assert isinstance(auth_constants.ADMIN_ACCESS_SCOPE, tuple)
# Should include all regular scopes
for scope in auth_constants.REGULAR_ACCESS_SCOPE:
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
# Should include users admin scope
for scope in auth_constants.USERS_ADMIN_SCOPE:
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
# Should include identity providers admin scope
for scope in auth_constants.IDENTITY_PROVIDERS_ADMIN_SCOPE:
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
# Should include server settings admin scope
for scope in auth_constants.SERVER_SETTINGS_ADMIN_SCOPE:
assert scope in auth_constants.ADMIN_ACCESS_SCOPE
def test_admin_scope_is_superset_of_regular(self):
"""Test that admin scope contains all permissions from regular scope."""
regular_set = set(auth_constants.REGULAR_ACCESS_SCOPE)
admin_set = set(auth_constants.ADMIN_ACCESS_SCOPE)
assert regular_set.issubset(
admin_set
), "Admin scope should contain all regular scope permissions"

View File

@@ -0,0 +1,311 @@
"""
Tests for auth.schema module.
This module tests Pydantic schemas and dependency classes for authentication,
including login requests, MFA management, and failed attempt tracking.
"""
import pytest
from datetime import datetime, timedelta, timezone
from pydantic import ValidationError
import auth.schema as auth_schema
class TestLoginRequest:
"""Tests for LoginRequest Pydantic model."""
def test_login_request_valid(self):
"""Test valid login request."""
request = auth_schema.LoginRequest(username="testuser", password="Password1!")
assert request.username == "testuser"
assert request.password == "Password1!"
def test_login_request_username_too_short(self):
"""Test login request with empty username."""
with pytest.raises(ValidationError) as exc_info:
auth_schema.LoginRequest(username="", password="Password1!")
assert "username" in str(exc_info.value)
def test_login_request_username_too_long(self):
"""Test login request with username exceeding max length."""
with pytest.raises(ValidationError) as exc_info:
auth_schema.LoginRequest(username="a" * 251, password="Password1!")
assert "username" in str(exc_info.value)
def test_login_request_password_too_short(self):
"""Test login request with password less than 8 characters."""
with pytest.raises(ValidationError) as exc_info:
auth_schema.LoginRequest(username="testuser", password="Pass1!")
assert "password" in str(exc_info.value)
class TestMFALoginRequest:
"""Tests for MFALoginRequest Pydantic model."""
def test_mfa_login_request_valid(self):
"""Test valid MFA login request with 6-digit code."""
request = auth_schema.MFALoginRequest(username="testuser", mfa_code="123456")
assert request.username == "testuser"
assert request.mfa_code == "123456"
def test_mfa_login_request_invalid_code_format_letters(self):
"""Test MFA login request with non-numeric code."""
with pytest.raises(ValidationError) as exc_info:
auth_schema.MFALoginRequest(username="testuser", mfa_code="12345a")
assert "mfa_code" in str(exc_info.value)
def test_mfa_login_request_invalid_code_too_short(self):
"""Test MFA login request with code less than 6 digits."""
with pytest.raises(ValidationError) as exc_info:
auth_schema.MFALoginRequest(username="testuser", mfa_code="12345")
assert "mfa_code" in str(exc_info.value)
def test_mfa_login_request_invalid_code_too_long(self):
"""Test MFA login request with code more than 6 digits."""
with pytest.raises(ValidationError) as exc_info:
auth_schema.MFALoginRequest(username="testuser", mfa_code="1234567")
assert "mfa_code" in str(exc_info.value)
class TestMFARequiredResponse:
"""Tests for MFARequiredResponse Pydantic model."""
def test_mfa_required_response_defaults(self):
"""Test MFA required response with default values."""
response = auth_schema.MFARequiredResponse(username="testuser")
assert response.mfa_required is True
assert response.username == "testuser"
assert response.message == "MFA verification required"
def test_mfa_required_response_custom_message(self):
"""Test MFA required response with custom message."""
response = auth_schema.MFARequiredResponse(
username="testuser", message="Custom MFA message"
)
assert response.mfa_required is True
assert response.message == "Custom MFA message"
def test_mfa_required_response_explicit_false(self):
"""Test MFA required response with explicit False."""
response = auth_schema.MFARequiredResponse(
mfa_required=False, username="testuser"
)
assert response.mfa_required is False
class TestPendingMFALogin:
"""Tests for PendingMFALogin class."""
def test_add_and_get_pending_login(self):
"""Test adding and retrieving pending MFA login."""
store = auth_schema.PendingMFALogin()
store.add_pending_login("testuser", 123)
assert store.get_pending_login("testuser") == 123
def test_get_pending_login_not_found(self):
"""Test getting non-existent pending login returns None."""
store = auth_schema.PendingMFALogin()
assert store.get_pending_login("nonexistent") is None
def test_has_pending_login(self):
"""Test checking if username has pending login."""
store = auth_schema.PendingMFALogin()
store.add_pending_login("testuser", 123)
assert store.has_pending_login("testuser") is True
assert store.has_pending_login("nonexistent") is False
def test_delete_pending_login(self):
"""Test deleting pending login."""
store = auth_schema.PendingMFALogin()
store.add_pending_login("testuser", 123)
store.delete_pending_login("testuser")
assert store.get_pending_login("testuser") is None
def test_delete_nonexistent_pending_login(self):
"""Test deleting non-existent pending login doesn't raise error."""
store = auth_schema.PendingMFALogin()
store.delete_pending_login("nonexistent") # Should not raise
def test_clear_all(self):
"""Test clearing all pending logins."""
store = auth_schema.PendingMFALogin()
store.add_pending_login("user1", 1)
store.add_pending_login("user2", 2)
store.clear_all()
assert store.get_pending_login("user1") is None
assert store.get_pending_login("user2") is None
def test_is_not_locked_out_initially(self):
"""Test user is not locked out initially."""
store = auth_schema.PendingMFALogin()
assert store.is_locked_out("testuser") is False
def test_lockout_after_5_failures(self):
"""Test 5-minute lockout after 5 failed attempts."""
store = auth_schema.PendingMFALogin()
for _ in range(5):
store.record_failed_attempt("testuser")
assert store.is_locked_out("testuser") is True
lockout_time = store.get_lockout_time("testuser")
assert lockout_time is not None
assert lockout_time > datetime.now(timezone.utc)
def test_lockout_after_10_failures(self):
"""Test 30-minute lockout after 10 failed attempts."""
store = auth_schema.PendingMFALogin()
for _ in range(10):
store.record_failed_attempt("testuser")
assert store.is_locked_out("testuser") is True
def test_lockout_after_15_failures(self):
"""Test 2-hour lockout after 15 failed attempts."""
store = auth_schema.PendingMFALogin()
for _ in range(15):
store.record_failed_attempt("testuser")
assert store.is_locked_out("testuser") is True
def test_failed_attempt_count_doesnt_increment_while_locked(self):
"""Test failed attempt counter doesn't increment during lockout."""
store = auth_schema.PendingMFALogin()
for _ in range(5):
store.record_failed_attempt("testuser")
# Try to increment during lockout
count_before = store.record_failed_attempt("testuser")
count_after = store.record_failed_attempt("testuser")
assert count_before == count_after
def test_reset_failed_attempts(self):
"""Test resetting failed attempts on successful verification."""
store = auth_schema.PendingMFALogin()
store.record_failed_attempt("testuser")
store.record_failed_attempt("testuser")
store.reset_failed_attempts("testuser")
assert store.is_locked_out("testuser") is False
assert store.get_lockout_time("testuser") is None
def test_get_lockout_time_returns_none_when_not_locked(self):
"""Test get_lockout_time returns None when user not locked."""
store = auth_schema.PendingMFALogin()
assert store.get_lockout_time("testuser") is None
def test_clear_all_clears_failed_attempts(self):
"""Test clear_all() clears both pending logins and failed attempts."""
store = auth_schema.PendingMFALogin()
store.add_pending_login("testuser", 123)
for _ in range(5):
store.record_failed_attempt("testuser")
store.clear_all()
assert store.is_locked_out("testuser") is False
class TestFailedLoginAttempts:
"""Tests for FailedLoginAttempts class."""
def test_is_not_locked_out_initially(self):
"""Test user is not locked out initially."""
tracker = auth_schema.FailedLoginAttempts()
assert tracker.is_locked_out("testuser") is False
def test_lockout_after_5_failures(self):
"""Test 5-minute lockout after 5 failed login attempts."""
tracker = auth_schema.FailedLoginAttempts()
for _ in range(5):
tracker.record_failed_attempt("testuser")
assert tracker.is_locked_out("testuser") is True
lockout_time = tracker.get_lockout_time("testuser")
assert lockout_time is not None
assert lockout_time > datetime.now(timezone.utc)
def test_lockout_after_10_failures(self):
"""Test 30-minute lockout after 10 failed login attempts."""
tracker = auth_schema.FailedLoginAttempts()
for _ in range(10):
tracker.record_failed_attempt("testuser")
assert tracker.is_locked_out("testuser") is True
def test_lockout_after_20_failures(self):
"""Test 24-hour lockout after 20 failed login attempts."""
tracker = auth_schema.FailedLoginAttempts()
for _ in range(20):
tracker.record_failed_attempt("testuser")
assert tracker.is_locked_out("testuser") is True
def test_failed_attempt_count_returns_correctly(self):
"""Test record_failed_attempt returns current count."""
tracker = auth_schema.FailedLoginAttempts()
count1 = tracker.record_failed_attempt("testuser")
count2 = tracker.record_failed_attempt("testuser")
count3 = tracker.record_failed_attempt("testuser")
assert count1 == 1
assert count2 == 2
assert count3 == 3
def test_failed_attempt_count_doesnt_increment_while_locked(self):
"""Test failed attempt counter doesn't increment during lockout."""
tracker = auth_schema.FailedLoginAttempts()
for _ in range(5):
tracker.record_failed_attempt("testuser")
# Try to increment during lockout
count_before = tracker.record_failed_attempt("testuser")
count_after = tracker.record_failed_attempt("testuser")
assert count_before == count_after
def test_reset_attempts(self):
"""Test resetting failed attempts on successful login."""
tracker = auth_schema.FailedLoginAttempts()
tracker.record_failed_attempt("testuser")
tracker.record_failed_attempt("testuser")
tracker.reset_attempts("testuser")
assert tracker.is_locked_out("testuser") is False
assert tracker.get_lockout_time("testuser") is None
def test_get_lockout_time_returns_none_when_not_locked(self):
"""Test get_lockout_time returns None when user not locked."""
tracker = auth_schema.FailedLoginAttempts()
assert tracker.get_lockout_time("testuser") is None
def test_clear_all(self):
"""Test clearing all failed attempt records."""
tracker = auth_schema.FailedLoginAttempts()
tracker.record_failed_attempt("user1")
tracker.record_failed_attempt("user2")
tracker.clear_all()
assert tracker.is_locked_out("user1") is False
assert tracker.is_locked_out("user2") is False
def test_different_users_tracked_independently(self):
"""Test different users have independent failed attempt tracking."""
tracker = auth_schema.FailedLoginAttempts()
for _ in range(3):
tracker.record_failed_attempt("user1")
for _ in range(5):
tracker.record_failed_attempt("user2")
assert tracker.is_locked_out("user1") is False
assert tracker.is_locked_out("user2") is True
class TestDependencyFunctions:
"""Tests for dependency injection functions."""
def test_get_pending_mfa_store(self):
"""Test get_pending_mfa_store returns PendingMFALogin instance."""
store = auth_schema.get_pending_mfa_store()
assert isinstance(store, auth_schema.PendingMFALogin)
def test_get_pending_mfa_store_returns_singleton(self):
"""Test get_pending_mfa_store returns same instance."""
store1 = auth_schema.get_pending_mfa_store()
store2 = auth_schema.get_pending_mfa_store()
assert store1 is store2
def test_get_failed_login_attempts(self):
"""Test get_failed_login_attempts returns FailedLoginAttempts instance."""
tracker = auth_schema.get_failed_login_attempts()
assert isinstance(tracker, auth_schema.FailedLoginAttempts)
def test_get_failed_login_attempts_returns_singleton(self):
"""Test get_failed_login_attempts returns same instance."""
tracker1 = auth_schema.get_failed_login_attempts()
tracker2 = auth_schema.get_failed_login_attempts()
assert tracker1 is tracker2

View File

@@ -0,0 +1,287 @@
"""Tests for auth.security module."""
import pytest
from fastapi import HTTPException
from fastapi.security import SecurityScopes
import auth.security as auth_security
import auth.token_manager as auth_token_manager
class TestGetToken:
"""Test get_token function for token retrieval logic."""
def test_get_access_token_from_header(self):
"""Test access token retrieval from Authorization header."""
result = auth_security.get_token(
non_cookie_token="test_token",
cookie_token=None,
client_type="web",
token_type="access",
)
assert result == "test_token"
def test_get_access_token_missing_raises_error(self):
"""Test that missing access token raises 401."""
with pytest.raises(HTTPException) as exc_info:
auth_security.get_token(
non_cookie_token=None,
cookie_token=None,
client_type="web",
token_type="access",
)
assert exc_info.value.status_code == 401
assert "Access token missing" in exc_info.value.detail
def test_get_refresh_token_from_cookie_for_web(self):
"""Test refresh token retrieval from cookie for web client."""
result = auth_security.get_token(
non_cookie_token=None,
cookie_token="refresh_cookie_token",
client_type="web",
token_type="refresh",
)
assert result == "refresh_cookie_token"
def test_get_refresh_token_from_header_for_mobile(self):
"""Test refresh token retrieval from header for mobile client."""
result = auth_security.get_token(
non_cookie_token="refresh_header_token",
cookie_token=None,
client_type="mobile",
token_type="refresh",
)
assert result == "refresh_header_token"
def test_get_refresh_token_missing_for_web_raises_error(self):
"""Test that missing refresh token from cookie for web raises 401."""
with pytest.raises(HTTPException) as exc_info:
auth_security.get_token(
non_cookie_token=None,
cookie_token=None,
client_type="web",
token_type="refresh",
)
assert exc_info.value.status_code == 401
assert "Refresh token missing from cookie" in exc_info.value.detail
def test_get_refresh_token_missing_for_mobile_raises_error(self):
"""Test that missing refresh token from header for mobile raises 401."""
with pytest.raises(HTTPException) as exc_info:
auth_security.get_token(
non_cookie_token=None,
cookie_token=None,
client_type="mobile",
token_type="refresh",
)
assert exc_info.value.status_code == 401
assert (
"Refresh token missing from Authorization header" in exc_info.value.detail
)
def test_invalid_token_type_raises_error(self):
"""Test that invalid token type raises 403."""
with pytest.raises(HTTPException) as exc_info:
auth_security.get_token(
non_cookie_token="test_token",
cookie_token=None,
client_type="web",
token_type="invalid_type",
)
assert exc_info.value.status_code == 403
assert "Invalid client type or token type" in exc_info.value.detail
class TestAccessTokenValidation:
"""Test access token validation functions."""
def test_validate_access_token_success(self, token_manager, sample_user_read):
"""Test successful access token validation."""
# Create a valid token
_, access_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
)
# Should not raise an exception
try:
auth_security.validate_access_token(access_token, token_manager)
except HTTPException:
pytest.fail("Valid token should not raise HTTPException")
def test_validate_access_token_with_expired_token(self, token_manager):
"""Test that expired token raises HTTPException."""
expired_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzaWQiOiJzZXNzaW9uLWlkIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwiYXVkIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwic3ViIjoxLCJzY29wZSI6WyJwcm9maWxlIl0sImlhdCI6MTc1OTk1MzE4NSwibmJmIjoxNzU5OTUzMTg1LCJleHAiOjE3NTk5NTQwODUsImp0aSI6Ijc5ZjY0MmVkLTQ3M2QtNDEwZi1hYzI1LTIyNjEwNTlhMzg2MiJ9.VSizGzvIIi_EJYD_YmfZBEBE_9aJbhLW-25cD1kEOeM"
with pytest.raises(HTTPException) as exc_info:
auth_security.validate_access_token(expired_token, token_manager)
assert exc_info.value.status_code == 401
def test_validate_access_token_with_invalid_token(self, token_manager):
"""Test that invalid token raises HTTPException."""
invalid_token = "invalid.token.here"
with pytest.raises(HTTPException) as exc_info:
auth_security.validate_access_token(invalid_token, token_manager)
assert exc_info.value.status_code == 401
class TestGetSubFromAccessToken:
"""Test extracting user ID from access token."""
def test_get_sub_from_valid_token(self, token_manager, sample_user_read):
"""Test extracting user ID from valid access token."""
_, access_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
)
sub = auth_security.get_sub_from_access_token(access_token, token_manager)
assert sub == sample_user_read.id
assert isinstance(sub, int)
def test_get_sub_from_invalid_token_raises_error(self, token_manager):
"""Test that invalid token raises HTTPException."""
invalid_token = "invalid.token.here"
with pytest.raises(HTTPException) as exc_info:
auth_security.get_sub_from_access_token(invalid_token, token_manager)
assert exc_info.value.status_code == 401
class TestGetSidFromAccessToken:
"""Test extracting session ID from access token."""
def test_get_sid_from_valid_token(self, token_manager, sample_user_read):
"""Test extracting session ID from valid access token."""
session_id = "test-session-123"
_, access_token = token_manager.create_token(
session_id, sample_user_read, auth_token_manager.TokenType.ACCESS
)
sid = auth_security.get_sid_from_access_token(access_token, token_manager)
assert sid == session_id
assert isinstance(sid, str)
def test_get_sid_from_invalid_token_raises_error(self, token_manager):
"""Test that invalid token raises HTTPException."""
invalid_token = "invalid.token.here"
with pytest.raises(HTTPException) as exc_info:
auth_security.get_sid_from_access_token(invalid_token, token_manager)
assert exc_info.value.status_code == 401
class TestRefreshTokenValidation:
"""Test refresh token validation functions."""
def test_validate_refresh_token_success(self, token_manager, sample_user_read):
"""Test successful refresh token validation."""
_, refresh_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.REFRESH
)
# Should not raise an exception
try:
auth_security.validate_refresh_token(refresh_token, token_manager)
except HTTPException:
pytest.fail("Valid refresh token should not raise HTTPException")
def test_validate_refresh_token_with_expired_token(self, token_manager):
"""Test that expired refresh token raises HTTPException."""
expired_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzaWQiOiJzZXNzaW9uLWlkIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwiYXVkIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwIiwic3ViIjoxLCJzY29wZSI6WyJwcm9maWxlIl0sImlhdCI6MTc1OTk1MzE4NSwibmJmIjoxNzU5OTUzMTg1LCJleHAiOjE3NTk5NTQwODUsImp0aSI6Ijc5ZjY0MmVkLTQ3M2QtNDEwZi1hYzI1LTIyNjEwNTlhMzg2MiJ9.VSizGzvIIi_EJYD_YmfZBEBE_9aJbhLW-25cD1kEOeM"
with pytest.raises(HTTPException) as exc_info:
auth_security.validate_refresh_token(expired_token, token_manager)
assert exc_info.value.status_code == 401
class TestGetSubFromRefreshToken:
"""Test extracting user ID from refresh token."""
def test_get_sub_from_valid_refresh_token(self, token_manager, sample_user_read):
"""Test extracting user ID from valid refresh token."""
_, refresh_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.REFRESH
)
sub = auth_security.get_sub_from_refresh_token(refresh_token, token_manager)
assert sub == sample_user_read.id
assert isinstance(sub, int)
class TestGetSidFromRefreshToken:
"""Test extracting session ID from refresh token."""
def test_get_sid_from_valid_refresh_token(self, token_manager, sample_user_read):
"""Test extracting session ID from valid refresh token."""
session_id = "test-session-456"
_, refresh_token = token_manager.create_token(
session_id, sample_user_read, auth_token_manager.TokenType.REFRESH
)
sid = auth_security.get_sid_from_refresh_token(refresh_token, token_manager)
assert sid == session_id
assert isinstance(sid, str)
class TestCheckScopes:
"""Test scope validation function."""
def test_check_scopes_with_valid_scopes(self, token_manager, sample_user_read):
"""Test that valid scopes pass validation."""
_, access_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
)
security_scopes = SecurityScopes(scopes=["profile", "users:read"])
# Should not raise an exception
try:
auth_security.check_scopes(access_token, token_manager, security_scopes)
except HTTPException:
pytest.fail("Valid scopes should not raise HTTPException")
def test_check_scopes_with_missing_scope(self, token_manager, sample_user_read):
"""Test that missing required scope raises 403."""
_, access_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
)
# Request a scope that the user doesn't have
security_scopes = SecurityScopes(scopes=["admin:write"])
with pytest.raises(HTTPException) as exc_info:
auth_security.check_scopes(access_token, token_manager, security_scopes)
assert exc_info.value.status_code == 403
assert "Missing permissions" in exc_info.value.detail
def test_check_scopes_with_no_required_scopes(
self, token_manager, sample_user_read
):
"""Test that no required scopes passes validation."""
_, access_token = token_manager.create_token(
"session-id", sample_user_read, auth_token_manager.TokenType.ACCESS
)
security_scopes = SecurityScopes(scopes=[])
# Should not raise an exception
try:
auth_security.check_scopes(access_token, token_manager, security_scopes)
except HTTPException:
pytest.fail("Empty required scopes should not raise HTTPException")
class TestGetAndReturnTokens:
"""Test simple token return functions."""
def test_get_and_return_access_token(self):
"""Test that access token is returned unchanged."""
test_token = "test_access_token"
result = auth_security.get_and_return_access_token(test_token)
assert result == test_token
def test_get_and_return_refresh_token(self):
"""Test that refresh token is returned unchanged."""
test_token = "test_refresh_token"
result = auth_security.get_and_return_refresh_token(test_token)
assert result == test_token

View File

@@ -0,0 +1,398 @@
"""Tests for auth.utils module."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException, Response
import auth.utils as auth_utils
import auth.token_manager as auth_token_manager
import users.user.schema as user_schema
class TestAuthenticateUser:
"""Test user authentication function."""
def test_authenticate_user_success(
self, password_hasher, mock_db, sample_user_read
):
"""Test successful user authentication."""
# Arrange
username = "testuser"
password = "TestPassword123!"
# Create a user with hashed password
hashed_password = password_hasher.hash_password(password)
mock_user = MagicMock()
mock_user.id = sample_user_read.id
mock_user.password = hashed_password
mock_user.username = username
# Mock the CRUD function to return our user
with patch("auth.utils.users_crud.authenticate_user", return_value=mock_user):
# Act
result = auth_utils.authenticate_user(
username, password, password_hasher, mock_db
)
# Assert
assert result == mock_user
def test_authenticate_user_invalid_username(self, password_hasher, mock_db):
"""Test authentication with invalid username raises 401."""
with patch("auth.utils.users_crud.authenticate_user", return_value=None):
with pytest.raises(HTTPException) as exc_info:
auth_utils.authenticate_user(
"nonexistent", "password", password_hasher, mock_db
)
assert exc_info.value.status_code == 401
assert "Invalid username" in exc_info.value.detail
def test_authenticate_user_invalid_password(self, password_hasher, mock_db):
"""Test authentication with invalid password raises 401."""
# Arrange
username = "testuser"
correct_password = "CorrectPassword123!"
wrong_password = "WrongPassword123!"
hashed_password = password_hasher.hash_password(correct_password)
mock_user = MagicMock()
mock_user.password = hashed_password
with patch("auth.utils.users_crud.authenticate_user", return_value=mock_user):
with pytest.raises(HTTPException) as exc_info:
auth_utils.authenticate_user(
username, wrong_password, password_hasher, mock_db
)
assert exc_info.value.status_code == 401
assert "Invalid password" in exc_info.value.detail
def test_authenticate_user_updates_password_hash_if_needed(
self, password_hasher, mock_db, sample_user_read
):
"""Test that password hash is updated if algorithm changed."""
# Arrange
username = "testuser"
password = "TestPassword123!"
# Use a different hasher to simulate old hash
from pwdlib.hashers.bcrypt import BcryptHasher
old_hasher_instance = BcryptHasher()
old_hash = old_hasher_instance.hash(password)
mock_user = MagicMock()
mock_user.id = sample_user_read.id
mock_user.password = old_hash
mock_user.username = username
with patch("auth.utils.users_crud.authenticate_user", return_value=mock_user):
with patch("auth.utils.users_crud.edit_user_password") as mock_edit:
# Act
result = auth_utils.authenticate_user(
username, password, password_hasher, mock_db
)
# Assert
assert result == mock_user
# Password should be updated since we're using different hasher
mock_edit.assert_called_once()
class TestCreateTokens:
"""Test token creation function."""
def test_create_tokens_generates_all_tokens(self, token_manager, sample_user_read):
"""Test that create_tokens generates all required tokens."""
# Act
(
session_id,
access_token_exp,
access_token,
refresh_token_exp,
refresh_token,
csrf_token,
) = auth_utils.create_tokens(sample_user_read, token_manager)
# Assert
assert session_id is not None
assert isinstance(session_id, str)
assert len(session_id) > 0
assert access_token_exp is not None
assert isinstance(access_token_exp, datetime)
assert access_token_exp > datetime.now(timezone.utc)
assert access_token is not None
assert isinstance(access_token, str)
assert len(access_token) > 0
assert refresh_token_exp is not None
assert isinstance(refresh_token_exp, datetime)
assert refresh_token_exp > datetime.now(timezone.utc)
assert refresh_token is not None
assert isinstance(refresh_token, str)
assert len(refresh_token) > 0
assert csrf_token is not None
assert isinstance(csrf_token, str)
assert len(csrf_token) >= 32
def test_create_tokens_with_custom_session_id(
self, token_manager, sample_user_read
):
"""Test that create_tokens uses provided session ID."""
# Arrange
custom_session_id = "custom-session-123"
# Act
(
session_id,
_,
_,
_,
_,
_,
) = auth_utils.create_tokens(sample_user_read, token_manager, custom_session_id)
# Assert
assert session_id == custom_session_id
def test_create_tokens_refresh_expires_after_access(
self, token_manager, sample_user_read
):
"""Test that refresh token expires after access token."""
# Act
(
_,
access_token_exp,
_,
refresh_token_exp,
_,
_,
) = auth_utils.create_tokens(sample_user_read, token_manager)
# Assert
assert refresh_token_exp > access_token_exp
def test_create_tokens_generates_unique_tokens(
self, token_manager, sample_user_read
):
"""Test that multiple calls generate unique tokens."""
# Act
(_, _, access_token1, _, refresh_token1, csrf_token1) = (
auth_utils.create_tokens(sample_user_read, token_manager)
)
(_, _, access_token2, _, refresh_token2, csrf_token2) = (
auth_utils.create_tokens(sample_user_read, token_manager)
)
# Assert
assert access_token1 != access_token2
assert refresh_token1 != refresh_token2
assert csrf_token1 != csrf_token2
class TestCompleteLogin:
"""Test complete_login function."""
def test_complete_login_for_web_client(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test complete_login for web client sets cookies and returns tokens."""
# Arrange
response = Response()
client_type = "web"
with patch("auth.utils.session_utils.create_session"):
# Act
result = auth_utils.complete_login(
response,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
# Assert
assert "session_id" in result
assert "access_token" in result
assert "csrf_token" in result
assert "token_type" in result
assert "expires_in" in result
assert result["token_type"] == "bearer"
assert isinstance(result["expires_in"], int)
# Check that refresh token cookie was set
assert "endurain_refresh_token" in response.headers.get("set-cookie", "")
def test_complete_login_for_mobile_client(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test complete_login for mobile client returns tokens."""
# Arrange
response = Response()
client_type = "mobile"
with patch("auth.utils.session_utils.create_session"):
# Act
result = auth_utils.complete_login(
response,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
# Assert
assert "session_id" in result
assert "access_token" in result
assert "csrf_token" in result
assert result["token_type"] == "bearer"
def test_complete_login_creates_session(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test that complete_login creates a session in the database."""
# Arrange
response = Response()
client_type = "web"
with patch("auth.utils.session_utils.create_session") as mock_create_session:
# Act
result = auth_utils.complete_login(
response,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
# Assert
mock_create_session.assert_called_once()
call_args = mock_create_session.call_args
# Verify session_id matches
assert call_args[0][0] == result["session_id"]
# Verify user was passed
assert call_args[0][1] == sample_user_read
# Verify request was passed
assert call_args[0][2] == mock_request
def test_complete_login_invalid_client_type_raises_error(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test that invalid client type raises 403."""
# Arrange
response = Response()
invalid_client_type = "invalid"
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
auth_utils.complete_login(
response,
mock_request,
sample_user_read,
invalid_client_type,
password_hasher,
token_manager,
mock_db,
)
assert exc_info.value.status_code == 403
assert "Invalid client type" in exc_info.value.detail
def test_complete_login_sets_secure_cookie_for_https(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test that secure flag is set on cookie when using HTTPS."""
# Arrange
response = Response()
client_type = "web"
with patch("auth.utils.session_utils.create_session"):
with patch.dict("os.environ", {"FRONTEND_PROTOCOL": "https"}):
# Act
auth_utils.complete_login(
response,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
# Assert
set_cookie_header = response.headers.get("set-cookie", "")
assert "secure" in set_cookie_header.lower()
def test_complete_login_cookie_attributes(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test that refresh token cookie has correct security attributes."""
# Arrange
response = Response()
client_type = "web"
with patch("auth.utils.session_utils.create_session"):
# Act
auth_utils.complete_login(
response,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
# Assert
set_cookie_header = response.headers.get("set-cookie", "")
assert "endurain_refresh_token" in set_cookie_header
assert "httponly" in set_cookie_header.lower()
assert "samesite=strict" in set_cookie_header.lower()
assert "path=/" in set_cookie_header.lower()
def test_complete_login_returns_different_tokens_on_each_call(
self, password_hasher, token_manager, mock_db, sample_user_read, mock_request
):
"""Test that each login generates unique tokens."""
# Arrange
response1 = Response()
response2 = Response()
client_type = "web"
with patch("auth.utils.session_utils.create_session"):
# Act
result1 = auth_utils.complete_login(
response1,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
result2 = auth_utils.complete_login(
response2,
mock_request,
sample_user_read,
client_type,
password_hasher,
token_manager,
mock_db,
)
# Assert
assert result1["session_id"] != result2["session_id"]
assert result1["access_token"] != result2["access_token"]
assert result1["csrf_token"] != result2["csrf_token"]

View File

@@ -1,3 +0,0 @@
"""
Tests for Endurain session module backend application.
"""

View File

@@ -1,994 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException, status
class TestLoginEndpointSecurity:
"""
Test suite for verifying the security and behavior of the login endpoint.
This class contains tests that cover various scenarios for the login endpoint, including:
- Successful login without Multi-Factor Authentication (MFA) for different client types.
- Login attempts when MFA is required, ensuring the correct response and MFA flow.
- Handling of invalid client types, ensuring forbidden access is enforced.
- Login attempts with invalid credentials, ensuring proper error handling.
- Login attempts with inactive users, ensuring access is denied as expected.
Each test uses extensive mocking to simulate authentication, user activity checks, MFA status, and session/token creation, allowing for isolated and reliable testing of the endpoint's logic and security requirements.
"""
@pytest.mark.parametrize(
"client_type, expected_status, returns_tokens",
[
("web", status.HTTP_200_OK, False),
("mobile", status.HTTP_200_OK, True),
],
)
def test_login_without_mfa(
self,
fast_api_app,
fast_api_client,
sample_user_read,
client_type,
expected_status,
returns_tokens,
):
"""
Test the login endpoint behavior when Multi-Factor Authentication (MFA) is not enabled for the user.
This test verifies that:
- The login process completes successfully without requiring MFA.
- The correct response is returned based on whether tokens are expected.
- The appropriate authentication, user activity, and MFA checks are patched and simulated.
- The fake store is not called during the process.
Args:
fast_api_app: The FastAPI application instance under test.
fast_api_client: The test client for making HTTP requests to the FastAPI app.
sample_user_read: A sample user object returned by the authentication mock.
client_type: The type of client making the request (used in headers and app state).
expected_status: The expected HTTP status code of the response.
returns_tokens: Boolean indicating if the endpoint should return tokens or just a session ID.
"""
fast_api_app.state._client_type = client_type
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
"session.router.users_utils.check_user_is_active"
), patch(
"session.router.profile_utils.is_mfa_enabled_for_user"
) as mock_mfa, patch(
"session.router.auth_utils.complete_login"
) as mock_complete:
mock_auth.return_value = sample_user_read
mock_mfa.return_value = False
mock_complete.return_value = (
{"session_id": "test-session"}
if not returns_tokens
else {
"access_token": "token",
"refresh_token": "refresh",
"session_id": "session",
"token_type": "Bearer",
"expires_in": 900,
}
)
resp = fast_api_client.post(
"/auth/login",
data={"username": "testuser", "password": "secret"},
headers={"X-Client-Type": client_type},
)
assert resp.status_code == expected_status
body = resp.json()
if returns_tokens:
assert body["access_token"] == "token"
assert body["refresh_token"] == "refresh"
assert body["session_id"] == "session"
assert body["token_type"] == "Bearer"
assert isinstance(body["expires_in"], int)
else:
assert body == {"session_id": "test-session"}
assert fast_api_app.state.fake_store.calls == []
@pytest.mark.parametrize(
"client_type, expected_status",
[
("web", status.HTTP_202_ACCEPTED),
("mobile", status.HTTP_200_OK),
],
)
def test_login_with_mfa_required(
self,
fast_api_app,
fast_api_client,
sample_user_read,
client_type,
expected_status,
):
"""
Test the login endpoint when Multi-Factor Authentication (MFA) is required.
This test verifies that when a user with MFA enabled attempts to log in,
the API responds with the correct status code and indicates that MFA is required.
It mocks the authentication, user activity check, and MFA status check to simulate
the scenario where MFA is enabled for the user.
Args:
fast_api_app: The FastAPI application instance under test.
fast_api_client: The test client for making HTTP requests to the FastAPI app.
sample_user_read: A sample user object returned by the authentication mock.
client_type: The type of client making the request (used in headers).
expected_status: The expected HTTP status code for the response.
Asserts:
- The response status code matches the expected status.
- The response JSON contains 'mfa_required' set to True.
- The response JSON contains the correct 'username'.
- The fake_store in the app state records the correct call.
"""
fast_api_app.state._client_type = client_type
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
"session.router.users_utils.check_user_is_active"
), patch("session.router.profile_utils.is_mfa_enabled_for_user") as mock_mfa:
mock_auth.return_value = sample_user_read
mock_mfa.return_value = True
resp = fast_api_client.post(
"/auth/login",
data={"username": "testuser", "password": "secret"},
headers={"X-Client-Type": client_type},
)
assert resp.status_code == expected_status
body = resp.json()
assert body["mfa_required"] is True
assert body["username"] == "testuser"
assert fast_api_app.state.fake_store.calls == [
("testuser", sample_user_read.id)
]
def test_invalid_client_type_forbidden(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test that a login attempt with an invalid client type returns a 403 Forbidden response.
This test sets the application's client type to "desktop" and mocks the authentication,
user activity check, MFA status, token creation, and session creation utilities. It then
sends a POST request to the "/auth/login" endpoint with the "X-Client-Type" header set to "desktop".
The test asserts that the response status code is 403 Forbidden and the response detail
indicates an invalid client type.
Args:
fast_api_app: The FastAPI application instance.
fast_api_client: The test client for making HTTP requests.
sample_user_read: A sample user object returned by the authentication mock.
"""
fast_api_app.state._client_type = "desktop"
with patch("session.router.auth_utils.authenticate_user") as mock_auth, patch(
"session.router.users_utils.check_user_is_active"
), patch(
"session.router.profile_utils.is_mfa_enabled_for_user"
) as mock_mfa, patch(
"session.router.auth_utils.create_tokens"
) as mock_create_tokens, patch(
"session.router.session_utils.create_session"
) as mock_create_session:
mock_auth.return_value = sample_user_read
mock_mfa.return_value = False
mock_create_tokens.return_value = (
"sid",
object(),
"acc",
object(),
"ref",
"csrf",
)
mock_create_session.return_value = None
resp = fast_api_client.post(
"/auth/login",
data={"username": "x", "password": "y"},
headers={"X-Client-Type": "desktop"},
)
assert resp.status_code == status.HTTP_403_FORBIDDEN
assert resp.json()["detail"] == "Invalid client type"
def test_login_with_invalid_credentials(self, password_hasher, mock_db):
"""
Test that the login endpoint raises an HTTPException with status code 401
when invalid credentials are provided. Mocks the authenticate_user function
to simulate authentication failure and verifies that the exception is raised
with the correct status code and detail.
"""
with patch("session.router.auth_utils.authenticate_user") as mock_auth:
mock_auth.side_effect = HTTPException(
status_code=401, detail="Invalid username"
)
with pytest.raises(HTTPException) as exc_info:
mock_auth("invalid", "password", password_hasher, mock_db)
assert exc_info.value.status_code == 401
def test_login_with_inactive_user(self, sample_inactive_user):
"""
Test that the login endpoint raises an HTTPException with status code 403
when attempting to authenticate an inactive user.
This test mocks the authentication and user activity check utilities to simulate
the scenario where a user is found but is inactive. It asserts that the correct
exception is raised with the expected status code.
"""
with patch("session.router.auth_utils.authenticate_user") as mock_auth:
with patch("session.router.users_utils.check_user_is_active") as mock_check:
mock_auth.return_value = sample_inactive_user
mock_check.side_effect = HTTPException(
status_code=403, detail="User is inactive"
)
with pytest.raises(HTTPException) as exc_info:
mock_check(sample_inactive_user)
assert exc_info.value.status_code == 403
class TestMFAVerifyEndpoint:
"""
Test suite for the MFA verification endpoint (/auth/mfa/verify).
This class contains tests that cover various scenarios for the MFA verification endpoint, including:
- Successful MFA verification and login for different client types (web and mobile).
- Handling of cases where no pending MFA login is found.
- Handling of invalid MFA codes.
- Handling of cases where the user is not found after MFA verification.
- Handling of inactive users after MFA verification.
- Handling of invalid client types during MFA verification.
Each test uses mocking to simulate the MFA verification flow, user lookup, and session creation.
"""
@pytest.mark.parametrize(
"client_type, expected_status, returns_tokens",
[
("web", status.HTTP_200_OK, False),
("mobile", status.HTTP_200_OK, True),
],
)
def test_mfa_verify_success(
self,
fast_api_app,
fast_api_client,
sample_user_read,
client_type,
expected_status,
returns_tokens,
):
"""
Test successful MFA verification and login completion.
This test verifies that when a valid MFA code is provided for a pending login,
the API successfully completes the login process and returns appropriate tokens
based on the client type.
Args:
fast_api_app: The FastAPI application instance under test.
fast_api_client: The test client for making HTTP requests.
sample_user_read: A sample user object.
client_type: The type of client ("web" or "mobile").
expected_status: The expected HTTP status code.
returns_tokens: Boolean indicating if tokens should be returned.
"""
fast_api_app.state._client_type = client_type
# Setup pending MFA login
pending_store = fast_api_app.state.fake_store
pending_store._store = {"testuser": sample_user_read.id}
with patch(
"session.router.profile_utils.verify_user_mfa"
) as mock_verify_mfa, patch(
"session.router.users_crud.get_user_by_id"
) as mock_get_user, patch(
"session.router.users_utils.check_user_is_active"
), patch(
"session.router.auth_utils.complete_login"
) as mock_complete:
mock_verify_mfa.return_value = True
mock_get_user.return_value = sample_user_read
mock_complete.return_value = (
{"session_id": "test-session"}
if not returns_tokens
else {
"access_token": "token",
"refresh_token": "refresh",
"session_id": "session",
"token_type": "Bearer",
"expires_in": 900,
}
)
resp = fast_api_client.post(
"/auth/mfa/verify",
json={"username": "testuser", "mfa_code": "123456"},
headers={"X-Client-Type": client_type},
)
assert resp.status_code == expected_status
body = resp.json()
if returns_tokens:
assert body["access_token"] == "token"
assert body["refresh_token"] == "refresh"
assert body["session_id"] == "session"
else:
assert body["session_id"] == "test-session"
def test_mfa_verify_no_pending_login(self, fast_api_app, fast_api_client):
"""
Test MFA verification when no pending login is found.
This test verifies that when attempting to verify MFA without a pending login,
the API returns a 400 Bad Request error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.fake_store._store = {}
resp = fast_api_client.post(
"/auth/mfa/verify",
json={"username": "testuser", "mfa_code": "123456"},
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_400_BAD_REQUEST
assert "No pending MFA login" in resp.json()["detail"]
def test_mfa_verify_invalid_code(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test MFA verification with an invalid MFA code.
This test verifies that when an invalid MFA code is provided,
the API returns a 401 Unauthorized error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.fake_store._store = {"testuser": sample_user_read.id}
with patch("session.router.profile_utils.verify_user_mfa") as mock_verify_mfa:
mock_verify_mfa.return_value = False
resp = fast_api_client.post(
"/auth/mfa/verify",
json={"username": "testuser", "mfa_code": "999999"},
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
assert "Invalid MFA code" in resp.json()["detail"]
def test_mfa_verify_user_not_found(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test MFA verification when user is not found after verification.
This test verifies that when a user cannot be found in the database
after MFA verification, the API returns a 404 Not Found error and
cleans up the pending login.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.fake_store._store = {"testuser": sample_user_read.id}
with patch(
"session.router.profile_utils.verify_user_mfa"
) as mock_verify_mfa, patch(
"session.router.users_crud.get_user_by_id"
) as mock_get_user:
mock_verify_mfa.return_value = True
mock_get_user.return_value = None
resp = fast_api_client.post(
"/auth/mfa/verify",
json={"username": "testuser", "mfa_code": "123456"},
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_404_NOT_FOUND
assert "User not found" in resp.json()["detail"]
def test_mfa_verify_inactive_user(
self, fast_api_app, fast_api_client, sample_inactive_user
):
"""
Test MFA verification with an inactive user.
This test verifies that when an inactive user attempts to complete MFA login,
the API returns a 403 Forbidden error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.fake_store._store = {"inactive": sample_inactive_user.id}
with patch(
"session.router.profile_utils.verify_user_mfa"
) as mock_verify_mfa, patch(
"session.router.users_crud.get_user_by_id"
) as mock_get_user, patch(
"session.router.users_utils.check_user_is_active"
) as mock_check_active:
mock_verify_mfa.return_value = True
mock_get_user.return_value = sample_inactive_user
mock_check_active.side_effect = HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="User is inactive"
)
resp = fast_api_client.post(
"/auth/mfa/verify",
json={"username": "inactive", "mfa_code": "123456"},
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_403_FORBIDDEN
assert "User is inactive" in resp.json()["detail"]
class TestRefreshTokenEndpoint:
"""
Test suite for the refresh token endpoint (/auth/refresh).
This class contains tests that cover various scenarios for the token refresh endpoint, including:
- Successful token refresh for different client types (web and mobile).
- Handling of session not found errors.
- Handling of invalid refresh token hash mismatches.
- Handling of inactive users during refresh.
- Handling of invalid client types during refresh.
Each test uses mocking to simulate token validation, session retrieval, and token creation.
"""
@pytest.mark.parametrize(
"client_type, expected_status, returns_tokens",
[
("web", status.HTTP_200_OK, False),
("mobile", status.HTTP_200_OK, True),
],
)
def test_refresh_token_success(
self,
fast_api_app,
fast_api_client,
sample_user_read,
password_hasher,
client_type,
expected_status,
returns_tokens,
):
"""
Test successful token refresh.
This test verifies that when a valid refresh token is provided,
the API successfully creates new tokens and returns them based on client type.
Args:
fast_api_app: The FastAPI application instance under test.
fast_api_client: The test client for making HTTP requests.
sample_user_read: A sample user object.
password_hasher: The password hasher instance.
mock_db: Mock database session.
client_type: The type of client ("web" or "mobile").
expected_status: The expected HTTP status code.
returns_tokens: Boolean indicating if tokens should be returned.
"""
fast_api_app.state._client_type = client_type
fast_api_app.state.mock_user_id = sample_user_read.id
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_refresh_token = "refresh_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
mock_session.refresh_token = password_hasher.hash_password(
"refresh_token_value"
)
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
), patch(
"session.router.users_crud.get_user_by_id", return_value=sample_user_read
), patch(
"session.router.users_utils.check_user_is_active"
), patch(
"session.router.auth_utils.create_tokens"
) as mock_create_tokens, patch(
"session.router.session_utils.edit_session"
), patch(
"session.router.auth_utils.create_response_with_tokens",
side_effect=lambda r, a, rf, c: r,
):
# Set up proper mock for create_tokens with timestamp
mock_access_exp = MagicMock()
mock_access_exp.timestamp.return_value = 1234567890
mock_refresh_exp = MagicMock()
mock_create_tokens.return_value = (
"new-session-id",
mock_access_exp,
"new_access_token",
mock_refresh_exp,
"new_refresh_token",
"new_csrf_token",
)
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/refresh",
headers={"X-Client-Type": client_type},
)
assert resp.status_code == expected_status
body = resp.json()
if returns_tokens:
assert body["access_token"] == "new_access_token"
assert body["refresh_token"] == "new_refresh_token"
assert body["session_id"] == "new-session-id"
assert body["token_type"] == "bearer"
else:
assert body["session_id"] == "new-session-id"
def test_refresh_token_session_not_found(self, fast_api_app, fast_api_client):
"""
Test token refresh when session is not found.
This test verifies that when attempting to refresh with a session ID
that doesn't exist, the API returns a 404 Not Found error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.mock_user_id = 1
fast_api_app.state.mock_session_id = "nonexistent-session"
fast_api_app.state.mock_refresh_token = "refresh_token_value"
with patch("session.router.session_crud.get_session_by_id", return_value=None):
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/refresh",
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_404_NOT_FOUND
assert "Session not found" in resp.json()["detail"]
def test_refresh_token_invalid_hash(
self, fast_api_app, fast_api_client, password_hasher
):
"""
Test token refresh with invalid refresh token hash.
This test verifies that when the refresh token hash doesn't match
the stored hash, the API returns a 401 Unauthorized error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.mock_user_id = 1
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_refresh_token = "wrong_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
mock_session.refresh_token = password_hasher.hash_password("different_token")
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
):
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_refresh_token", "wrong_token_value")
resp = fast_api_client.post(
"/auth/refresh",
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
assert "Invalid refresh token" in resp.json()["detail"]
def test_refresh_token_inactive_user(
self, fast_api_app, fast_api_client, sample_inactive_user, password_hasher
):
"""
Test token refresh with an inactive user.
This test verifies that when an inactive user attempts to refresh tokens,
the API returns a 403 Forbidden error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.mock_user_id = sample_inactive_user.id
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_refresh_token = "refresh_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
mock_session.refresh_token = password_hasher.hash_password(
"refresh_token_value"
)
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
), patch(
"session.router.users_crud.get_user_by_id",
return_value=sample_inactive_user,
), patch(
"session.router.users_utils.check_user_is_active",
side_effect=HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="User is inactive"
),
):
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/refresh",
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_403_FORBIDDEN
def test_refresh_token_invalid_client_type(
self, fast_api_app, fast_api_client, sample_user_read, password_hasher
):
"""
Test token refresh with an invalid client type.
This test verifies that when an invalid client type is provided,
the API returns a 403 Forbidden error.
"""
fast_api_app.state._client_type = "desktop"
fast_api_app.state.mock_user_id = sample_user_read.id
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_refresh_token = "refresh_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
mock_session.refresh_token = password_hasher.hash_password(
"refresh_token_value"
)
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
), patch(
"session.router.users_crud.get_user_by_id", return_value=sample_user_read
), patch(
"session.router.users_utils.check_user_is_active"
), patch(
"session.router.auth_utils.create_tokens"
) as mock_create_tokens, patch(
"session.router.session_utils.edit_session"
):
# Set up proper mock for create_tokens with timestamp
mock_access_exp = MagicMock()
mock_access_exp.timestamp.return_value = 1234567890
mock_refresh_exp = MagicMock()
mock_create_tokens.return_value = (
"new-session-id",
mock_access_exp,
"new_access_token",
mock_refresh_exp,
"new_refresh_token",
"new_csrf_token",
)
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/refresh",
headers={"X-Client-Type": "desktop"},
)
assert resp.status_code == status.HTTP_403_FORBIDDEN
assert "Invalid client type" in resp.json()["detail"]
class TestLogoutEndpoint:
"""
Test suite for the logout endpoint (/auth/logout).
This class contains tests that cover various scenarios for the logout endpoint, including:
- Successful logout for different client types (web and mobile).
- Cookie clearing for web clients.
- Handling of invalid refresh tokens during logout.
- Handling of session not found during logout (should still succeed).
- Handling of invalid client types during logout.
Each test uses mocking to simulate token validation, session retrieval, and session deletion.
"""
@pytest.mark.parametrize(
"client_type, expected_status",
[
("web", status.HTTP_200_OK),
("mobile", status.HTTP_200_OK),
],
)
def test_logout_success(
self,
fast_api_app,
fast_api_client,
password_hasher,
client_type,
expected_status,
):
"""
Test successful logout.
This test verifies that when a valid access and refresh token are provided,
the API successfully deletes the session and returns a success message.
Args:
fast_api_app: The FastAPI application instance under test.
fast_api_client: The test client for making HTTP requests.
password_hasher: The password hasher instance.
client_type: The type of client ("web" or "mobile").
expected_status: The expected HTTP status code.
"""
fast_api_app.state._client_type = client_type
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_user_id = 1
fast_api_app.state.mock_refresh_token = "refresh_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
mock_session.refresh_token = password_hasher.hash_password(
"refresh_token_value"
)
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
), patch("session.router.session_crud.delete_session") as mock_delete:
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_access_token", "access_token")
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/logout",
headers={"X-Client-Type": client_type},
)
assert resp.status_code == expected_status
assert resp.json()["message"] == "Logout successful"
mock_delete.assert_called_once()
# Check cookies are cleared for web clients
if client_type == "web":
# The response should have set-cookie headers to clear cookies
assert "set-cookie" in resp.headers or resp.cookies
def test_logout_invalid_refresh_token(
self, fast_api_app, fast_api_client, password_hasher
):
"""
Test logout with an invalid refresh token.
This test verifies that when the refresh token hash doesn't match,
the API returns a 401 Unauthorized error.
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_user_id = 1
fast_api_app.state.mock_refresh_token = "wrong_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
mock_session.refresh_token = password_hasher.hash_password("different_token")
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
):
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_access_token", "access_token")
fast_api_client.cookies.set("endurain_refresh_token", "wrong_token_value")
resp = fast_api_client.post(
"/auth/logout",
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
assert "Invalid refresh token" in resp.json()["detail"]
def test_logout_session_not_found_still_succeeds(
self, fast_api_app, fast_api_client
):
"""
Test logout when session is not found (should still succeed).
This test verifies that when attempting to logout with a session ID
that doesn't exist, the API still returns success (idempotent operation).
"""
fast_api_app.state._client_type = "web"
fast_api_app.state.mock_session_id = "nonexistent-session"
fast_api_app.state.mock_user_id = 1
fast_api_app.state.mock_refresh_token = "refresh_token_value"
with patch("session.router.session_crud.get_session_by_id", return_value=None):
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_access_token", "access_token")
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/logout",
headers={"X-Client-Type": "web"},
)
assert resp.status_code == status.HTTP_200_OK
assert resp.json()["message"] == "Logout successful"
def test_logout_invalid_client_type(self, fast_api_app, fast_api_client):
"""
Test logout with an invalid client type.
This test verifies that when an invalid client type is provided,
the API returns a 401 Unauthorized error (client type validation
happens after authentication in the dependency chain).
"""
fast_api_app.state._client_type = "desktop"
fast_api_app.state.mock_session_id = "test-session-id"
fast_api_app.state.mock_user_id = 1
fast_api_app.state.mock_refresh_token = "refresh_token_value"
mock_session = MagicMock()
mock_session.id = "test-session-id"
with patch(
"session.router.session_crud.get_session_by_id", return_value=mock_session
):
# Set cookies on client instance (new API)
fast_api_client.cookies.set("endurain_access_token", "access_token")
fast_api_client.cookies.set("endurain_refresh_token", "refresh_token_value")
resp = fast_api_client.post(
"/auth/logout",
headers={"X-Client-Type": "desktop"},
)
# Client type validation happens in the header_client_type_scheme dependency
# which runs after authentication, so we get 401 due to invalid client type
# being rejected by the scheme validator
assert resp.status_code == status.HTTP_401_UNAUTHORIZED
class TestSessionsEndpoints:
"""
Test suite for the sessions management endpoints.
This class contains tests for:
- GET /sessions/user/{user_id} - Retrieve all sessions for a user
- DELETE /sessions/{session_id}/user/{user_id} - Delete a specific session
Each test uses mocking to simulate authentication, authorization, and database operations.
"""
def test_read_sessions_user_success(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test successful retrieval of user sessions.
This test verifies that when a valid access token is provided with
appropriate scopes, the API returns all sessions for the user.
"""
fast_api_app.state._client_type = "web"
mock_sessions = [
{
"id": "session-1",
"user_id": sample_user_read.id,
"device_type": "desktop",
"browser": "Chrome",
},
{
"id": "session-2",
"user_id": sample_user_read.id,
"device_type": "mobile",
"browser": "Safari",
},
]
with patch(
"session.router.session_crud.get_user_sessions", return_value=mock_sessions
) as mock_get_sessions:
resp = fast_api_client.get(
f"/sessions/user/{sample_user_read.id}",
headers={
"X-Client-Type": "web",
"Authorization": "Bearer access_token",
},
)
assert resp.status_code == status.HTTP_200_OK
assert len(resp.json()) == 2
assert resp.json()[0]["id"] == "session-1"
assert resp.json()[1]["id"] == "session-2"
mock_get_sessions.assert_called_once()
def test_read_sessions_user_empty_list(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test retrieval of user sessions when no sessions exist.
This test verifies that when a user has no active sessions,
the API returns an empty list.
"""
fast_api_app.state._client_type = "web"
with patch("session.router.session_crud.get_user_sessions", return_value=[]):
resp = fast_api_client.get(
f"/sessions/user/{sample_user_read.id}",
headers={
"X-Client-Type": "web",
"Authorization": "Bearer access_token",
},
)
assert resp.status_code == status.HTTP_200_OK
assert resp.json() == []
def test_delete_session_success(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test successful deletion of a user session.
This test verifies that when a valid access token is provided with
appropriate scopes, the API successfully deletes the specified session.
"""
fast_api_app.state._client_type = "web"
session_id = "session-to-delete"
with patch(
"session.router.session_crud.delete_session", return_value=True
) as mock_delete:
resp = fast_api_client.delete(
f"/sessions/{session_id}/user/{sample_user_read.id}",
headers={
"X-Client-Type": "web",
"Authorization": "Bearer access_token",
},
)
assert resp.status_code == status.HTTP_200_OK
# Verify delete_session was called with the correct session_id and user_id
# (the third argument is the database session which we don't need to verify)
assert mock_delete.called
call_args = mock_delete.call_args[0]
assert call_args[0] == session_id
assert call_args[1] == sample_user_read.id
def test_delete_session_not_found(
self, fast_api_app, fast_api_client, sample_user_read
):
"""
Test deletion of a non-existent session.
This test verifies that when attempting to delete a session that doesn't exist,
the API handles it appropriately (implementation-dependent behavior).
"""
fast_api_app.state._client_type = "web"
session_id = "nonexistent-session"
with patch(
"session.router.session_crud.delete_session",
side_effect=HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session not found"
),
):
resp = fast_api_client.delete(
f"/sessions/{session_id}/user/{sample_user_read.id}",
headers={
"X-Client-Type": "web",
"Authorization": "Bearer access_token",
},
)
assert resp.status_code == status.HTTP_404_NOT_FOUND
assert "Session not found" in resp.json()["detail"]

File diff suppressed because it is too large Load Diff