mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
refactor(backend): enhance storage and retrieval of blocked domains (#12273)
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
"""create blocked_email_domains table
|
||||
|
||||
Revision ID: 086
|
||||
Revises: 085
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '086'
|
||||
down_revision: Union[str, None] = '085'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create blocked_email_domains table for storing blocked email domain patterns."""
|
||||
op.create_table(
|
||||
'blocked_email_domains',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('domain', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
# Create unique index on domain column
|
||||
op.create_index(
|
||||
'ix_blocked_email_domains_domain',
|
||||
'blocked_email_domains',
|
||||
['domain'],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop blocked_email_domains table."""
|
||||
op.drop_index('ix_blocked_email_domains_domain', table_name='blocked_email_domains')
|
||||
op.drop_table('blocked_email_domains')
|
||||
@@ -38,8 +38,3 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'y',
|
||||
'on',
|
||||
)
|
||||
BLOCKED_EMAIL_DOMAINS = [
|
||||
domain.strip().lower()
|
||||
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
|
||||
if domain.strip()
|
||||
]
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
|
||||
from storage.blocked_email_domain_store import BlockedEmailDomainStore
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, store: BlockedEmailDomainStore) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
|
||||
if self.blocked_domains:
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
|
||||
)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if domain blocking is enabled"""
|
||||
return bool(self.blocked_domains)
|
||||
self.store = store
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
@@ -31,16 +24,16 @@ class DomainBlocker:
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked
|
||||
"""Check if email domain is blocked by querying the database directly via SQL.
|
||||
|
||||
Supports blocking:
|
||||
- Exact domains: 'example.com' blocks 'user@example.com'
|
||||
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
|
||||
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
|
||||
"""
|
||||
if not self.is_active():
|
||||
return False
|
||||
|
||||
The blocking logic is handled efficiently in SQL, avoiding the need to load
|
||||
all blocked domains into memory.
|
||||
"""
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
@@ -50,26 +43,25 @@ class DomainBlocker:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
# Check if domain matches any blocked pattern
|
||||
for blocked_pattern in self.blocked_domains:
|
||||
if blocked_pattern.startswith('.'):
|
||||
# TLD pattern (e.g., '.us') - check if domain ends with it
|
||||
if domain.endswith(blocked_pattern):
|
||||
logger.warning(
|
||||
f'Email domain {domain} is blocked by TLD pattern {blocked_pattern} for email: {email}'
|
||||
)
|
||||
return True
|
||||
try:
|
||||
# Query database directly via SQL to check if domain is blocked
|
||||
is_blocked = self.store.is_domain_blocked(domain)
|
||||
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
# Full domain pattern (e.g., 'example.com')
|
||||
# Block exact match or subdomains
|
||||
if domain == blocked_pattern or domain.endswith(f'.{blocked_pattern}'):
|
||||
logger.warning(
|
||||
f'Email domain {domain} is blocked by domain pattern {blocked_pattern} for email: {email}'
|
||||
)
|
||||
return True
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
return False
|
||||
return is_blocked
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error checking if domain is blocked for email {email}: {e}',
|
||||
exc_info=True,
|
||||
)
|
||||
# Fail-safe: if database query fails, don't block (allow auth to proceed)
|
||||
return False
|
||||
|
||||
|
||||
domain_blocker = DomainBlocker()
|
||||
# Initialize store and domain blocker
|
||||
_store = BlockedEmailDomainStore(session_maker=session_maker)
|
||||
domain_blocker = DomainBlocker(store=_store)
|
||||
|
||||
@@ -317,7 +317,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
|
||||
@@ -151,7 +151,7 @@ async def keycloak_callback(
|
||||
|
||||
# Check if email domain is blocked
|
||||
email = user_info.get('email')
|
||||
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
|
||||
30
enterprise/storage/blocked_email_domain.py
Normal file
30
enterprise/storage/blocked_email_domain.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class BlockedEmailDomain(Base): # type: ignore
|
||||
"""Stores blocked email domain patterns.
|
||||
|
||||
Supports blocking:
|
||||
- Exact domains: 'example.com' blocks 'user@example.com'
|
||||
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
|
||||
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
|
||||
"""
|
||||
|
||||
__tablename__ = 'blocked_email_domains'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
domain = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
45
enterprise/storage/blocked_email_domain_store.py
Normal file
45
enterprise/storage/blocked_email_domain_store.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockedEmailDomainStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
def is_domain_blocked(self, domain: str) -> bool:
|
||||
"""Check if a domain is blocked by querying the database directly.
|
||||
|
||||
This method uses SQL to efficiently check if the domain matches any blocked pattern:
|
||||
- TLD patterns (e.g., '.us'): checks if domain ends with the pattern
|
||||
- Full domain patterns (e.g., 'example.com'): checks for exact match or subdomain match
|
||||
|
||||
Args:
|
||||
domain: The extracted domain from the email (e.g., 'example.com' or 'subdomain.example.com')
|
||||
|
||||
Returns:
|
||||
True if the domain is blocked, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# SQL query that handles both TLD patterns and full domain patterns
|
||||
# TLD patterns (starting with '.'): check if domain ends with the pattern
|
||||
# Full domain patterns: check for exact match or subdomain match
|
||||
# All comparisons are case-insensitive using LOWER() to ensure consistent matching
|
||||
query = text("""
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM blocked_email_domains
|
||||
WHERE
|
||||
-- TLD pattern (e.g., '.us') - check if domain ends with it (case-insensitive)
|
||||
(LOWER(domain) LIKE '.%' AND LOWER(:domain) LIKE '%' || LOWER(domain)) OR
|
||||
-- Full domain pattern (e.g., 'example.com')
|
||||
-- Block exact match or subdomains (case-insensitive)
|
||||
(LOWER(domain) NOT LIKE '.%' AND (
|
||||
LOWER(:domain) = LOWER(domain) OR
|
||||
LOWER(:domain) LIKE '%.' || LOWER(domain)
|
||||
))
|
||||
)
|
||||
""")
|
||||
result = session.execute(query, {'domain': domain}).scalar()
|
||||
return bool(result)
|
||||
@@ -546,7 +546,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
@@ -600,7 +599,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -621,7 +619,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
"""Test keycloak_callback when domain blocking is not active."""
|
||||
"""Test keycloak_callback when email domain is not blocked."""
|
||||
# Arrange
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
@@ -654,7 +652,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -666,7 +664,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, RedirectResponse)
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
mock_token_manager.disable_keycloak_user.assert_not_called()
|
||||
|
||||
|
||||
@@ -705,8 +703,6 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
|
||||
@@ -1,33 +1,21 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def domain_blocker():
|
||||
"""Create a DomainBlocker instance for testing."""
|
||||
return DomainBlocker()
|
||||
def mock_store():
|
||||
"""Create a mock BlockedEmailDomainStore for testing."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'blocked_domains,expected',
|
||||
[
|
||||
(['colsch.us', 'other-domain.com'], True),
|
||||
(['example.com'], True),
|
||||
([], False),
|
||||
],
|
||||
)
|
||||
def test_is_active(domain_blocker, blocked_domains, expected):
|
||||
"""Test that is_active returns correct value based on blocked domains configuration."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = blocked_domains
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_active()
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
@pytest.fixture
|
||||
def domain_blocker(mock_store):
|
||||
"""Create a DomainBlocker instance for testing with a mocked store."""
|
||||
return DomainBlocker(store=mock_store)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -69,94 +57,104 @@ def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_is_domain_blocked_when_inactive(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when blocking is not active."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = []
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker):
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker):
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker):
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when domain is not blocked."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker):
|
||||
"""Test that is_domain_blocked returns True when domain is in blocked list."""
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns True when domain is blocked."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain matching."""
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain extraction."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked correctly checks multiple domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = lambda domain: domain in [
|
||||
'other-domain.com',
|
||||
'blocked.org',
|
||||
]
|
||||
|
||||
# Act
|
||||
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
@@ -167,109 +165,71 @@ def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is False
|
||||
assert mock_store.is_domain_blocked.call_count == 3
|
||||
|
||||
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['colsch.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TLD Blocking Tests (patterns starting with '.')
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(domain_blocker):
|
||||
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern blocks domains ending with that TLD."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@company.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.us')
|
||||
|
||||
|
||||
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(domain_blocker):
|
||||
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern blocks subdomains with that TLD."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@subdomain.company.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us')
|
||||
|
||||
|
||||
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(domain_blocker):
|
||||
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern does not block domains with different TLD."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us']
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@company.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.com')
|
||||
|
||||
|
||||
def test_is_domain_blocked_tld_pattern_does_not_block_substring_match(
|
||||
domain_blocker,
|
||||
):
|
||||
"""Test that TLD pattern does not block domains that contain but don't end with the TLD."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@focus.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker):
|
||||
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker, mock_store):
|
||||
"""Test that TLD pattern matching is case-insensitive."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@COMPANY.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.us')
|
||||
|
||||
|
||||
def test_is_domain_blocked_multiple_tld_patterns(domain_blocker):
|
||||
"""Test blocking with multiple TLD patterns."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us', '.vn', '.com']
|
||||
|
||||
# Act
|
||||
result_us = domain_blocker.is_domain_blocked('user@test.us')
|
||||
result_vn = domain_blocker.is_domain_blocked('user@test.vn')
|
||||
result_com = domain_blocker.is_domain_blocked('user@test.com')
|
||||
result_org = domain_blocker.is_domain_blocked('user@test.org')
|
||||
|
||||
# Assert
|
||||
assert result_us is True
|
||||
assert result_vn is True
|
||||
assert result_com is True
|
||||
assert result_org is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker):
|
||||
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker, mock_store):
|
||||
"""Test that TLD pattern works with multi-level TLDs like .co.uk."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.co.uk']
|
||||
mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk')
|
||||
|
||||
# Act
|
||||
result_match = domain_blocker.is_domain_blocked('user@example.co.uk')
|
||||
@@ -282,81 +242,87 @@ def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker):
|
||||
assert result_no_match is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Subdomain Blocking Tests (domain patterns now block subdomains)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_pattern_blocks_exact_match(domain_blocker):
|
||||
def test_is_domain_blocked_domain_pattern_blocks_exact_match(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks exact domain match."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['example.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker):
|
||||
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker, mock_store):
|
||||
"""Test that domain pattern blocks subdomains of that domain."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['example.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@subdomain.example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com')
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
|
||||
domain_blocker,
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks multi-level subdomains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['example.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@api.v2.example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com')
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
|
||||
domain_blocker,
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern does not block domains that contain but don't match the pattern."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['example.com']
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@notexample.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('notexample.com')
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
|
||||
domain_blocker,
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern does not block same domain with different TLD."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['example.com']
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.org')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.org')
|
||||
|
||||
|
||||
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_blocker):
|
||||
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that blocking a subdomain also blocks its nested subdomains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['api.example.com']
|
||||
mock_store.is_domain_blocked.side_effect = (
|
||||
lambda domain: 'api.example.com' in domain
|
||||
)
|
||||
|
||||
# Act
|
||||
result_exact = domain_blocker.is_domain_blocked('user@api.example.com')
|
||||
@@ -369,80 +335,10 @@ def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_bloc
|
||||
assert result_parent is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mixed Pattern Tests (TLD + domain patterns together)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_is_domain_blocked_mixed_patterns_tld_and_domain(domain_blocker):
|
||||
"""Test blocking with both TLD and domain patterns."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us', 'openhands.dev']
|
||||
|
||||
# Act
|
||||
result_tld = domain_blocker.is_domain_blocked('user@company.us')
|
||||
result_domain = domain_blocker.is_domain_blocked('user@openhands.dev')
|
||||
result_subdomain = domain_blocker.is_domain_blocked('user@api.openhands.dev')
|
||||
result_allowed = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result_tld is True
|
||||
assert result_domain is True
|
||||
assert result_subdomain is True
|
||||
assert result_allowed is False
|
||||
|
||||
|
||||
def test_is_domain_blocked_overlapping_patterns(domain_blocker):
|
||||
"""Test that overlapping patterns (TLD and specific domain) both work."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.us', 'test.us']
|
||||
|
||||
# Act
|
||||
result_specific = domain_blocker.is_domain_blocked('user@test.us')
|
||||
result_other_us = domain_blocker.is_domain_blocked('user@other.us')
|
||||
|
||||
# Assert
|
||||
assert result_specific is True
|
||||
assert result_other_us is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_complex_multi_pattern_scenario(domain_blocker):
|
||||
"""Test complex scenario with multiple TLD and domain patterns."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = [
|
||||
'.us',
|
||||
'.vn',
|
||||
'test.com',
|
||||
'openhands.dev',
|
||||
]
|
||||
|
||||
# Act & Assert
|
||||
# TLD patterns
|
||||
assert domain_blocker.is_domain_blocked('user@anything.us') is True
|
||||
assert domain_blocker.is_domain_blocked('user@company.vn') is True
|
||||
|
||||
# Domain patterns (exact)
|
||||
assert domain_blocker.is_domain_blocked('user@test.com') is True
|
||||
assert domain_blocker.is_domain_blocked('user@openhands.dev') is True
|
||||
|
||||
# Domain patterns (subdomains)
|
||||
assert domain_blocker.is_domain_blocked('user@api.test.com') is True
|
||||
assert domain_blocker.is_domain_blocked('user@staging.openhands.dev') is True
|
||||
|
||||
# Not blocked
|
||||
assert domain_blocker.is_domain_blocked('user@allowed.com') is False
|
||||
assert domain_blocker.is_domain_blocked('user@example.org') is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Case Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_with_hyphens(domain_blocker):
|
||||
def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
|
||||
"""Test that domain patterns work with hyphenated domains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['my-company.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result_exact = domain_blocker.is_domain_blocked('user@my-company.com')
|
||||
@@ -451,12 +347,13 @@ def test_is_domain_blocked_domain_with_hyphens(domain_blocker):
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
assert result_subdomain is True
|
||||
assert mock_store.is_domain_blocked.call_count == 2
|
||||
|
||||
|
||||
def test_is_domain_blocked_domain_with_numbers(domain_blocker):
|
||||
def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
|
||||
"""Test that domain patterns work with numeric domains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['test123.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result_exact = domain_blocker.is_domain_blocked('user@test123.com')
|
||||
@@ -465,24 +362,13 @@ def test_is_domain_blocked_domain_with_numbers(domain_blocker):
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
assert result_subdomain is True
|
||||
assert mock_store.is_domain_blocked.call_count == 2
|
||||
|
||||
|
||||
def test_is_domain_blocked_short_tld(domain_blocker):
|
||||
"""Test that short TLD patterns work correctly."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['.io']
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@company.io')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker):
|
||||
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
|
||||
"""Test that blocking works with very long subdomain chains."""
|
||||
# Arrange
|
||||
domain_blocker.blocked_domains = ['example.com']
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked(
|
||||
@@ -491,3 +377,19 @@ def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker):
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with(
|
||||
'level4.level3.level2.level1.example.com'
|
||||
)
|
||||
|
||||
|
||||
def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when store raises an exception."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = Exception('Database connection error')
|
||||
|
||||
# Act
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
@@ -673,7 +673,6 @@ async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act & Assert
|
||||
@@ -703,7 +702,6 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
@@ -720,7 +718,7 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
|
||||
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
|
||||
# Arrange
|
||||
access_payload = {
|
||||
'sub': 'test_user_id',
|
||||
@@ -737,7 +735,7 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
|
||||
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
|
||||
|
||||
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await saas_user_auth_from_signed_token(signed_token)
|
||||
@@ -745,4 +743,4 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
|
||||
# Assert
|
||||
assert isinstance(result, SaasUserAuth)
|
||||
assert result.user_id == 'test_user_id'
|
||||
mock_domain_blocker.is_domain_blocked.assert_not_called()
|
||||
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
|
||||
|
||||
Reference in New Issue
Block a user