refactor(backend): enhance storage and retrieval of blocked domains (#12273)

This commit is contained in:
Hiep Le
2026-01-07 13:41:43 +07:00
committed by GitHub
parent 08df955ba7
commit 8ddb815a89
10 changed files with 280 additions and 268 deletions

View File

@@ -0,0 +1,54 @@
"""create blocked_email_domains table
Revision ID: 086
Revises: 085
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '086'
down_revision: Union[str, None] = '085'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``blocked_email_domains`` table and its unique domain index.

    The table holds one blocked email-domain pattern per row, together with
    creation and update timestamps that default to the database's current time.
    """

    def timestamp_column(name: str) -> sa.Column:
        # Both audit columns share the same shape: timezone-aware, non-null,
        # and defaulted server-side to CURRENT_TIMESTAMP.
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text('CURRENT_TIMESTAMP'),
        )

    table_name = 'blocked_email_domains'
    columns = [
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('domain', sa.String(), nullable=False),
        timestamp_column('created_at'),
        timestamp_column('updated_at'),
    ]
    op.create_table(table_name, *columns, sa.PrimaryKeyConstraint('id'))

    # A unique index guarantees each domain pattern is stored at most once.
    op.create_index(
        'ix_blocked_email_domains_domain',
        table_name,
        ['domain'],
        unique=True,
    )
def downgrade() -> None:
    """Reverse the upgrade: drop the domain index, then the table itself."""
    table_name = 'blocked_email_domains'
    # The index is removed explicitly before the table it belongs to.
    op.drop_index('ix_blocked_email_domains_domain', table_name=table_name)
    op.drop_table(table_name)

View File

@@ -38,8 +38,3 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
'y',
'on',
)
BLOCKED_EMAIL_DOMAINS = [
domain.strip().lower()
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
if domain.strip()
]

View File

@@ -1,20 +1,13 @@
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
from storage.blocked_email_domain_store import BlockedEmailDomainStore
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
class DomainBlocker:
def __init__(self) -> None:
def __init__(self, store: BlockedEmailDomainStore) -> None:
logger.debug('Initializing DomainBlocker')
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
if self.blocked_domains:
logger.info(
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
)
def is_active(self) -> bool:
"""Check if domain blocking is enabled"""
return bool(self.blocked_domains)
self.store = store
def _extract_domain(self, email: str) -> str | None:
"""Extract and normalize email domain from email address"""
@@ -31,16 +24,16 @@ class DomainBlocker:
return None
def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked
"""Check if email domain is blocked by querying the database directly via SQL.
Supports blocking:
- Exact domains: 'example.com' blocks 'user@example.com'
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
"""
if not self.is_active():
return False
The blocking logic is handled efficiently in SQL, avoiding the need to load
all blocked domains into memory.
"""
if not email:
logger.debug('No email provided for domain check')
return False
@@ -50,26 +43,25 @@ class DomainBlocker:
logger.debug(f'Could not extract domain from email: {email}')
return False
# Check if domain matches any blocked pattern
for blocked_pattern in self.blocked_domains:
if blocked_pattern.startswith('.'):
# TLD pattern (e.g., '.us') - check if domain ends with it
if domain.endswith(blocked_pattern):
logger.warning(
f'Email domain {domain} is blocked by TLD pattern {blocked_pattern} for email: {email}'
)
return True
try:
# Query database directly via SQL to check if domain is blocked
is_blocked = self.store.is_domain_blocked(domain)
if is_blocked:
logger.warning(f'Email domain {domain} is blocked for email: {email}')
else:
# Full domain pattern (e.g., 'example.com')
# Block exact match or subdomains
if domain == blocked_pattern or domain.endswith(f'.{blocked_pattern}'):
logger.warning(
f'Email domain {domain} is blocked by domain pattern {blocked_pattern} for email: {email}'
)
return True
logger.debug(f'Email domain {domain} is not blocked')
logger.debug(f'Email domain {domain} is not blocked')
return False
return is_blocked
except Exception as e:
logger.error(
f'Error checking if domain is blocked for email {email}: {e}',
exc_info=True,
)
# Fail-safe: if database query fails, don't block (allow auth to proceed)
return False
domain_blocker = DomainBlocker()
# Initialize store and domain blocker
_store = BlockedEmailDomainStore(session_maker=session_maker)
domain_blocker = DomainBlocker(store=_store)

View File

@@ -317,7 +317,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
if email and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)

View File

@@ -151,7 +151,7 @@ async def keycloak_callback(
# Check if email domain is blocked
email = user_info.get('email')
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
if email and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)

View File

@@ -0,0 +1,30 @@
from datetime import UTC, datetime
from sqlalchemy import Column, DateTime, Identity, Integer, String
from storage.base import Base
def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(UTC)


class BlockedEmailDomain(Base):  # type: ignore
    """ORM model for one blocked email-domain pattern.

    Each row's ``domain`` value is a pattern. Pattern forms:
    - Exact domains: 'example.com' blocks 'user@example.com'
    - Subdomains: 'example.com' blocks 'user@subdomain.example.com'
    - TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
    """

    __tablename__ = 'blocked_email_domains'

    id = Column(Integer, Identity(), primary_key=True)
    # The pattern text; declared unique so a pattern cannot be stored twice.
    domain = Column(String, unique=True, nullable=False)
    # Set once, client-side, when the row is first inserted.
    created_at = Column(DateTime(timezone=True), nullable=False, default=_utc_now)
    # Set on insert and refreshed on every ORM update of the row.
    updated_at = Column(
        DateTime(timezone=True),
        nullable=False,
        default=_utc_now,
        onupdate=_utc_now,
    )

View File

@@ -0,0 +1,45 @@
from dataclasses import dataclass
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
@dataclass
class BlockedEmailDomainStore:
    """Lookup of blocked email-domain patterns in ``blocked_email_domains``."""

    # Factory whose call result is used as a context-managed database session.
    session_maker: sessionmaker

    def is_domain_blocked(self, domain: str) -> bool:
        """Check whether a domain matches any stored blocked pattern.

        Matching rules (case-insensitive):
        - TLD patterns (starting with '.', e.g. '.us') match when the domain
          ends with the pattern.
        - Full domain patterns (e.g. 'example.com') match the domain itself
          and any of its subdomains.

        Every pattern that can match is a suffix of ``domain`` that starts
        either at a dot or immediately after one (or is the whole domain).
        Those candidates are computed here and compared by equality in SQL.
        This avoids building a LIKE pattern from stored values, where a
        '_', '%' or '\\' in a stored pattern would act as a wildcard/escape
        and block unrelated domains.

        Args:
            domain: The extracted domain from the email
                (e.g., 'example.com' or 'subdomain.example.com').

        Returns:
            True if the domain is blocked, False otherwise. An empty or
            missing domain is never blocked.
        """
        # Local import: the module's top-level import only brings in `text`.
        from sqlalchemy import bindparam

        if not domain:
            return False

        normalized = domain.lower()
        candidates = {normalized}
        for index, char in enumerate(normalized):
            if char == '.':
                # Suffix including the dot covers TLD-style patterns ('.us').
                candidates.add(normalized[index:])
                # Suffix after the dot covers parent-domain patterns ('example.com').
                candidates.add(normalized[index + 1 :])

        # `expanding=True` renders the list as one bound parameter per element.
        query = text(
            """
            SELECT EXISTS(
                SELECT 1
                FROM blocked_email_domains
                WHERE LOWER(domain) IN :candidates
            )
            """
        ).bindparams(bindparam('candidates', expanding=True))

        with self.session_maker() as session:
            result = session.execute(
                query, {'candidates': sorted(candidates)}
            ).scalar()
        return bool(result)

View File

@@ -546,7 +546,6 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
)
mock_token_manager.disable_keycloak_user = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act
@@ -600,7 +599,6 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
@@ -621,7 +619,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
@pytest.mark.asyncio
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
"""Test keycloak_callback when domain blocking is not active."""
"""Test keycloak_callback when email domain is not blocked."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
@@ -654,7 +652,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = False
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -666,7 +664,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_not_called()
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
mock_token_manager.disable_keycloak_user.assert_not_called()
@@ -705,8 +703,6 @@ async def test_keycloak_callback_missing_email(mock_request):
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = True
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True

View File

@@ -1,33 +1,21 @@
"""Unit tests for DomainBlocker class."""
from unittest.mock import MagicMock
import pytest
from server.auth.domain_blocker import DomainBlocker
@pytest.fixture
def domain_blocker():
"""Create a DomainBlocker instance for testing."""
return DomainBlocker()
def mock_store():
"""Create a mock BlockedEmailDomainStore for testing."""
return MagicMock()
@pytest.mark.parametrize(
'blocked_domains,expected',
[
(['colsch.us', 'other-domain.com'], True),
(['example.com'], True),
([], False),
],
)
def test_is_active(domain_blocker, blocked_domains, expected):
"""Test that is_active returns correct value based on blocked domains configuration."""
# Arrange
domain_blocker.blocked_domains = blocked_domains
# Act
result = domain_blocker.is_active()
# Assert
assert result == expected
@pytest.fixture
def domain_blocker(mock_store):
"""Create a DomainBlocker instance for testing with a mocked store."""
return DomainBlocker(store=mock_store)
@pytest.mark.parametrize(
@@ -69,94 +57,104 @@ def test_extract_domain_invalid_emails(domain_blocker, email, expected):
assert result == expected
def test_is_domain_blocked_when_inactive(domain_blocker):
"""Test that is_domain_blocked returns False when blocking is not active."""
# Arrange
domain_blocker.blocked_domains = []
# Act
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is False
def test_is_domain_blocked_with_none_email(domain_blocker):
def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when email is None."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked(None)
# Assert
assert result is False
mock_store.is_domain_blocked.assert_not_called()
def test_is_domain_blocked_with_empty_email(domain_blocker):
def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when email is empty."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_not_called()
def test_is_domain_blocked_with_invalid_email(domain_blocker):
def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when email format is invalid."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('invalid-email')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_not_called()
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when domain is not blocked."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
mock_store.is_domain_blocked.return_value = False
# Act
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('example.com')
def test_is_domain_blocked_domain_blocked(domain_blocker):
"""Test that is_domain_blocked returns True when domain is in blocked list."""
def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
"""Test that is_domain_blocked returns True when domain is blocked."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
def test_is_domain_blocked_case_insensitive(domain_blocker):
"""Test that is_domain_blocked performs case-insensitive domain matching."""
def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
"""Test that is_domain_blocked performs case-insensitive domain extraction."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
"""Test that is_domain_blocked handles emails with whitespace correctly."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
"""Test that is_domain_blocked correctly checks multiple domains."""
# Arrange
mock_store.is_domain_blocked.side_effect = lambda domain: domain in [
'other-domain.com',
'blocked.org',
]
# Act
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
@@ -167,109 +165,71 @@ def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
assert result1 is True
assert result2 is True
assert result3 is False
assert mock_store.is_domain_blocked.call_count == 3
def test_is_domain_blocked_with_whitespace(domain_blocker):
"""Test that is_domain_blocked handles emails with whitespace correctly."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
# Assert
assert result is True
# ============================================================================
# TLD Blocking Tests (patterns starting with '.')
# ============================================================================
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(domain_blocker):
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
domain_blocker, mock_store
):
"""Test that TLD pattern blocks domains ending with that TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@company.us')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('company.us')
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(domain_blocker):
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
domain_blocker, mock_store
):
"""Test that TLD pattern blocks subdomains with that TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@subdomain.company.us')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us')
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(domain_blocker):
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
domain_blocker, mock_store
):
"""Test that TLD pattern does not block domains with different TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
mock_store.is_domain_blocked.return_value = False
# Act
result = domain_blocker.is_domain_blocked('user@company.com')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('company.com')
def test_is_domain_blocked_tld_pattern_does_not_block_substring_match(
domain_blocker,
):
"""Test that TLD pattern does not block domains that contain but don't end with the TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
# Act
result = domain_blocker.is_domain_blocked('user@focus.com')
# Assert
assert result is False
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker):
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker, mock_store):
"""Test that TLD pattern matching is case-insensitive."""
# Arrange
domain_blocker.blocked_domains = ['.us']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@COMPANY.US')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('company.us')
def test_is_domain_blocked_multiple_tld_patterns(domain_blocker):
"""Test blocking with multiple TLD patterns."""
# Arrange
domain_blocker.blocked_domains = ['.us', '.vn', '.com']
# Act
result_us = domain_blocker.is_domain_blocked('user@test.us')
result_vn = domain_blocker.is_domain_blocked('user@test.vn')
result_com = domain_blocker.is_domain_blocked('user@test.com')
result_org = domain_blocker.is_domain_blocked('user@test.org')
# Assert
assert result_us is True
assert result_vn is True
assert result_com is True
assert result_org is False
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker):
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker, mock_store):
"""Test that TLD pattern works with multi-level TLDs like .co.uk."""
# Arrange
domain_blocker.blocked_domains = ['.co.uk']
mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk')
# Act
result_match = domain_blocker.is_domain_blocked('user@example.co.uk')
@@ -282,81 +242,87 @@ def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker):
assert result_no_match is False
# ============================================================================
# Subdomain Blocking Tests (domain patterns now block subdomains)
# ============================================================================
def test_is_domain_blocked_domain_pattern_blocks_exact_match(domain_blocker):
def test_is_domain_blocked_domain_pattern_blocks_exact_match(
domain_blocker, mock_store
):
"""Test that domain pattern blocks exact domain match."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('example.com')
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker):
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker, mock_store):
"""Test that domain pattern blocks subdomains of that domain."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@subdomain.example.com')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com')
def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
domain_blocker,
domain_blocker, mock_store
):
"""Test that domain pattern blocks multi-level subdomains."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked('user@api.v2.example.com')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com')
def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
domain_blocker,
domain_blocker, mock_store
):
"""Test that domain pattern does not block domains that contain but don't match the pattern."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
mock_store.is_domain_blocked.return_value = False
# Act
result = domain_blocker.is_domain_blocked('user@notexample.com')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('notexample.com')
def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
domain_blocker,
domain_blocker, mock_store
):
"""Test that domain pattern does not block same domain with different TLD."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
mock_store.is_domain_blocked.return_value = False
# Act
result = domain_blocker.is_domain_blocked('user@example.org')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('example.org')
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_blocker):
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
domain_blocker, mock_store
):
"""Test that blocking a subdomain also blocks its nested subdomains."""
# Arrange
domain_blocker.blocked_domains = ['api.example.com']
mock_store.is_domain_blocked.side_effect = (
lambda domain: 'api.example.com' in domain
)
# Act
result_exact = domain_blocker.is_domain_blocked('user@api.example.com')
@@ -369,80 +335,10 @@ def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_bloc
assert result_parent is False
# ============================================================================
# Mixed Pattern Tests (TLD + domain patterns together)
# ============================================================================
def test_is_domain_blocked_mixed_patterns_tld_and_domain(domain_blocker):
"""Test blocking with both TLD and domain patterns."""
# Arrange
domain_blocker.blocked_domains = ['.us', 'openhands.dev']
# Act
result_tld = domain_blocker.is_domain_blocked('user@company.us')
result_domain = domain_blocker.is_domain_blocked('user@openhands.dev')
result_subdomain = domain_blocker.is_domain_blocked('user@api.openhands.dev')
result_allowed = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result_tld is True
assert result_domain is True
assert result_subdomain is True
assert result_allowed is False
def test_is_domain_blocked_overlapping_patterns(domain_blocker):
"""Test that overlapping patterns (TLD and specific domain) both work."""
# Arrange
domain_blocker.blocked_domains = ['.us', 'test.us']
# Act
result_specific = domain_blocker.is_domain_blocked('user@test.us')
result_other_us = domain_blocker.is_domain_blocked('user@other.us')
# Assert
assert result_specific is True
assert result_other_us is True
def test_is_domain_blocked_complex_multi_pattern_scenario(domain_blocker):
"""Test complex scenario with multiple TLD and domain patterns."""
# Arrange
domain_blocker.blocked_domains = [
'.us',
'.vn',
'test.com',
'openhands.dev',
]
# Act & Assert
# TLD patterns
assert domain_blocker.is_domain_blocked('user@anything.us') is True
assert domain_blocker.is_domain_blocked('user@company.vn') is True
# Domain patterns (exact)
assert domain_blocker.is_domain_blocked('user@test.com') is True
assert domain_blocker.is_domain_blocked('user@openhands.dev') is True
# Domain patterns (subdomains)
assert domain_blocker.is_domain_blocked('user@api.test.com') is True
assert domain_blocker.is_domain_blocked('user@staging.openhands.dev') is True
# Not blocked
assert domain_blocker.is_domain_blocked('user@allowed.com') is False
assert domain_blocker.is_domain_blocked('user@example.org') is False
# ============================================================================
# Edge Case Tests
# ============================================================================
def test_is_domain_blocked_domain_with_hyphens(domain_blocker):
def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
"""Test that domain patterns work with hyphenated domains."""
# Arrange
domain_blocker.blocked_domains = ['my-company.com']
mock_store.is_domain_blocked.return_value = True
# Act
result_exact = domain_blocker.is_domain_blocked('user@my-company.com')
@@ -451,12 +347,13 @@ def test_is_domain_blocked_domain_with_hyphens(domain_blocker):
# Assert
assert result_exact is True
assert result_subdomain is True
assert mock_store.is_domain_blocked.call_count == 2
def test_is_domain_blocked_domain_with_numbers(domain_blocker):
def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
"""Test that domain patterns work with numeric domains."""
# Arrange
domain_blocker.blocked_domains = ['test123.com']
mock_store.is_domain_blocked.return_value = True
# Act
result_exact = domain_blocker.is_domain_blocked('user@test123.com')
@@ -465,24 +362,13 @@ def test_is_domain_blocked_domain_with_numbers(domain_blocker):
# Assert
assert result_exact is True
assert result_subdomain is True
assert mock_store.is_domain_blocked.call_count == 2
def test_is_domain_blocked_short_tld(domain_blocker):
"""Test that short TLD patterns work correctly."""
# Arrange
domain_blocker.blocked_domains = ['.io']
# Act
result = domain_blocker.is_domain_blocked('user@company.io')
# Assert
assert result is True
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker):
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
"""Test that blocking works with very long subdomain chains."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
mock_store.is_domain_blocked.return_value = True
# Act
result = domain_blocker.is_domain_blocked(
@@ -491,3 +377,19 @@ def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker):
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with(
'level4.level3.level2.level1.example.com'
)
def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
    """is_domain_blocked reports False (fails open) when the store lookup raises."""
    # Arrange: make the store lookup raise on every call.
    store_failure = Exception('Database connection error')
    mock_store.is_domain_blocked.side_effect = store_failure

    # Act
    blocked = domain_blocker.is_domain_blocked('user@example.com')

    # Assert: the store was consulted exactly once with the bare domain,
    # and the raised error did not propagate to the caller.
    mock_store.is_domain_blocked.assert_called_once_with('example.com')
    assert blocked is False

View File

@@ -673,7 +673,6 @@ async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act & Assert
@@ -703,7 +702,6 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
# Act
@@ -720,7 +718,7 @@ async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
# Arrange
access_payload = {
'sub': 'test_user_id',
@@ -737,7 +735,7 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = False
mock_domain_blocker.is_domain_blocked.return_value = False
# Act
result = await saas_user_auth_from_signed_token(signed_token)
@@ -745,4 +743,4 @@ async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_co
# Assert
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
mock_domain_blocker.is_domain_blocked.assert_not_called()
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')