refactor: Move check_byor_export_enabled to OrgService and add tests (PR #12753 followup) (#12782)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-02-06 10:03:03 -07:00
committed by GitHub
parent 872f41e3c0
commit 1d1c0925b5
6 changed files with 239 additions and 28 deletions

View File

@@ -6,30 +6,13 @@ from storage.api_key_store import ApiKeyStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.org_store import OrgStore
from storage.org_service import OrgService
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
async def check_byor_export_enabled(user_id: str) -> bool:
"""Check if BYOR export is enabled for the user's current org.
Returns True if the user's current org has byor_export_enabled set to True.
Returns False if the user is not found, has no current org, or the flag is False.
"""
user = await UserStore.get_user_by_id_async(user_id)
if not user or not user.current_org_id:
return False
org = OrgStore.get_org_by_id(user.current_org_id)
if not org:
return False
return org.byor_export_enabled
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
@@ -173,7 +156,7 @@ class ByorPermittedResponse(BaseModel):
async def check_byor_permitted(user_id: str = Depends(get_user_id)):
"""Check if BYOR key export is permitted for the user's current org."""
try:
permitted = await check_byor_export_enabled(user_id)
permitted = await OrgService.check_byor_export_enabled(user_id)
return {'permitted': permitted}
except Exception as e:
logger.exception(
@@ -295,7 +278,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
"""
try:
# Check if BYOR export is enabled for the user's org
if not await check_byor_export_enabled(user_id):
if not await OrgService.check_byor_export_enabled(user_id):
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
@@ -364,7 +347,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
try:
# Check if BYOR export is enabled for the user's org
if not await check_byor_export_enabled(user_id):
if not await OrgService.check_byor_export_enabled(user_id):
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',

View File

@@ -17,6 +17,7 @@ from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org_store import OrgStore
from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
@@ -259,6 +260,9 @@ async def success_callback(session_id: str, request: Request):
str(user.current_org_id), new_max_budget
)
# Enable BYOR export for the org now that they've purchased credits
OrgStore.update_org(user.current_org_id, {'byor_export_enabled': True})
# Store transaction status
billing_session.status = 'completed'
billing_session.price = add_credits

View File

@@ -842,3 +842,26 @@ class OrgService:
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
@staticmethod
async def check_byor_export_enabled(user_id: str) -> bool:
"""Check if BYOR export is enabled for the user's current org.
Returns True if the user's current org has byor_export_enabled set to True.
Returns False if the user is not found, has no current org, or the flag is False.
Args:
user_id: User ID to check
Returns:
bool: True if BYOR export is enabled, False otherwise
"""
user = await UserStore.get_user_by_id_async(user_id)
if not user or not user.current_org_id:
return False
org = OrgStore.get_org_by_id(user.current_org_id)
if not org:
return False
return org.byor_export_enabled

View File

@@ -6,6 +6,7 @@ import httpx
import pytest
from fastapi import HTTPException
from server.routes.api_keys import (
check_byor_permitted,
delete_byor_key_from_litellm,
get_llm_api_key_for_byor,
)
@@ -182,7 +183,7 @@ class TestGetLlmApiKeyForByor:
"""Test the get_llm_api_key_for_byor endpoint."""
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
@@ -209,7 +210,7 @@ class TestGetLlmApiKeyForByor:
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_valid_key_in_database_returns_key(
@@ -233,7 +234,7 @@ class TestGetLlmApiKeyForByor:
mock_verify_key.assert_called_once_with(existing_key, user_id)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@@ -273,7 +274,7 @@ class TestGetLlmApiKeyForByor:
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@@ -311,7 +312,7 @@ class TestGetLlmApiKeyForByor:
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_key_generation_failure_raises_exception(
@@ -332,7 +333,7 @@ class TestGetLlmApiKeyForByor:
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_database_error_raises_exception(
self, mock_get_key, mock_check_enabled
@@ -351,7 +352,7 @@ class TestGetLlmApiKeyForByor:
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.api_keys.check_byor_export_enabled')
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
"""Test that when BYOR export is disabled, 402 is returned."""
# Arrange
@@ -460,3 +461,52 @@ class TestDeleteByorKeyFromLitellm:
# Assert
assert result is False
class TestCheckByorPermitted:
"""Test the check_byor_permitted endpoint."""
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_permitted_when_enabled(self, mock_check_enabled):
"""Test that permitted=True is returned when BYOR export is enabled."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = True
# Act
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == {'permitted': True}
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_not_permitted_when_disabled(self, mock_check_enabled):
"""Test that permitted=False is returned when BYOR export is disabled."""
# Arrange
user_id = 'user-123'
mock_check_enabled.return_value = False
# Act
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == {'permitted': False}
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio
@patch('storage.org_service.OrgService.check_byor_export_enabled')
async def test_error_raises_500(self, mock_check_enabled):
"""Test that an exception raises 500 error."""
# Arrange
user_id = 'user-123'
mock_check_enabled.side_effect = Exception('Database error')
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_byor_permitted(user_id=user_id)
assert exc_info.value.status_code == 500
assert 'Failed to check BYOR export permission' in exc_info.value.detail

View File

@@ -317,6 +317,7 @@ async def test_success_callback_success():
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
) as mock_update_budget,
patch('server.routes.billing.OrgStore.update_org') as mock_update_org,
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
@@ -340,6 +341,12 @@ async def test_success_callback_success():
125.0, # 100 + (25.00 from Stripe)
)
# Verify BYOR export is enabled for the org
mock_update_org.assert_called_once_with(
'mock_org_id',
{'byor_export_enabled': True},
)
# Verify database updates
assert mock_billing_session.status == 'completed'
assert mock_billing_session.price == 25.0

View File

@@ -1657,3 +1657,147 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker):
assert result.contact_name == 'Jane Doe'
assert result.conversation_expiration == 60
assert result.enable_proactive_conversation_starters is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_true_when_enabled():
"""
GIVEN: User has current_org with byor_export_enabled=True
WHEN: check_byor_export_enabled is called
THEN: Returns True
"""
# Arrange
user_id = 'test-user-123'
org_id = uuid.uuid4()
mock_user = MagicMock()
mock_user.current_org_id = org_id
mock_org = MagicMock()
mock_org.byor_export_enabled = True
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is True
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_disabled():
"""
GIVEN: User has current_org with byor_export_enabled=False
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'test-user-123'
org_id = uuid.uuid4()
mock_user = MagicMock()
mock_user.current_org_id = org_id
mock_org = MagicMock()
mock_org.byor_export_enabled = False
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_user_not_found():
"""
GIVEN: User does not exist
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'nonexistent-user'
with patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=None),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_no_current_org():
"""
GIVEN: User exists but has no current_org_id
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'test-user-123'
mock_user = MagicMock()
mock_user.current_org_id = None
with patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_byor_export_enabled_returns_false_when_org_not_found():
"""
GIVEN: User has current_org_id but org does not exist
WHEN: check_byor_export_enabled is called
THEN: Returns False
"""
# Arrange
user_id = 'test-user-123'
org_id = uuid.uuid4()
mock_user = MagicMock()
mock_user.current_org_id = org_id
with (
patch(
'storage.org_service.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
),
patch(
'storage.org_service.OrgStore.get_org_by_id',
return_value=None,
),
):
# Act
result = await OrgService.check_byor_export_enabled(user_id)
# Assert
assert result is False