From 1d1c0925b5a50f1abb4aaa60bd5ec9ac07b11ca2 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Fri, 6 Feb 2026 10:03:03 -0700 Subject: [PATCH] refactor: Move check_byor_export_enabled to OrgService and add tests (PR #12753 followup) (#12782) Co-authored-by: openhands --- enterprise/server/routes/api_keys.py | 25 +-- enterprise/server/routes/billing.py | 4 + enterprise/storage/org_service.py | 23 +++ .../tests/unit/server/routes/test_api_keys.py | 64 +++++++- enterprise/tests/unit/test_billing.py | 7 + enterprise/tests/unit/test_org_service.py | 144 ++++++++++++++++++ 6 files changed, 239 insertions(+), 28 deletions(-) diff --git a/enterprise/server/routes/api_keys.py b/enterprise/server/routes/api_keys.py index 711937cf2e..aea7572af2 100644 --- a/enterprise/server/routes/api_keys.py +++ b/enterprise/server/routes/api_keys.py @@ -6,30 +6,13 @@ from storage.api_key_store import ApiKeyStore from storage.lite_llm_manager import LiteLlmManager from storage.org_member import OrgMember from storage.org_member_store import OrgMemberStore -from storage.org_store import OrgStore +from storage.org_service import OrgService from storage.user_store import UserStore from openhands.core.logger import openhands_logger as logger from openhands.server.user_auth import get_user_id -async def check_byor_export_enabled(user_id: str) -> bool: - """Check if BYOR export is enabled for the user's current org. - - Returns True if the user's current org has byor_export_enabled set to True. - Returns False if the user is not found, has no current org, or the flag is False. - """ - user = await UserStore.get_user_by_id_async(user_id) - if not user or not user.current_org_id: - return False - - org = OrgStore.get_org_by_id(user.current_org_id) - if not org: - return False - - return org.byor_export_enabled - - # Helper functions for BYOR API key management async def get_byor_key_from_db(user_id: str) -> str | None: """Get the BYOR key from the database for a user.""" @@ -173,7 +156,7 @@ class ByorPermittedResponse(BaseModel): async def check_byor_permitted(user_id: str = Depends(get_user_id)): """Check if BYOR key export is permitted for the user's current org.""" try: - permitted = await check_byor_export_enabled(user_id) + permitted = await OrgService.check_byor_export_enabled(user_id) return {'permitted': permitted} except Exception as e: logger.exception( @@ -295,7 +278,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)): """ try: # Check if BYOR export is enabled for the user's org - if not await check_byor_export_enabled(user_id): + if not await OrgService.check_byor_export_enabled(user_id): raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail='BYOR key export is not enabled. Purchase credits to enable this feature.', @@ -364,7 +347,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)): try: # Check if BYOR export is enabled for the user's org - if not await check_byor_export_enabled(user_id): + if not await OrgService.check_byor_export_enabled(user_id): raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail='BYOR key export is not enabled. Purchase credits to enable this feature.', diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index 19e9dfa17c..f99d90f710 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -17,6 +17,7 @@ from starlette.datastructures import URL from storage.billing_session import BillingSession from storage.database import session_maker from storage.lite_llm_manager import LiteLlmManager +from storage.org_store import OrgStore from storage.subscription_access import SubscriptionAccess from storage.user_store import UserStore @@ -259,6 +260,9 @@ async def success_callback(session_id: str, request: Request): str(user.current_org_id), new_max_budget ) + # Enable BYOR export for the org now that they've purchased credits + OrgStore.update_org(user.current_org_id, {'byor_export_enabled': True}) + # Store transaction status billing_session.status = 'completed' billing_session.price = add_credits diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index 8669ee498c..af25c05a91 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -842,3 +842,26 @@ class OrgService: extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)}, ) raise OrgDatabaseError(f'Failed to delete organization: {str(e)}') + + @staticmethod + async def check_byor_export_enabled(user_id: str) -> bool: + """Check if BYOR export is enabled for the user's current org. + + Returns True if the user's current org has byor_export_enabled set to True. + Returns False if the user is not found, has no current org, or the flag is False. + + Args: + user_id: User ID to check + + Returns: + bool: True if BYOR export is enabled, False otherwise + """ + user = await UserStore.get_user_by_id_async(user_id) + if not user or not user.current_org_id: + return False + + org = OrgStore.get_org_by_id(user.current_org_id) + if not org: + return False + + return org.byor_export_enabled diff --git a/enterprise/tests/unit/server/routes/test_api_keys.py b/enterprise/tests/unit/server/routes/test_api_keys.py index 703a0cf467..3629a8c44f 100644 --- a/enterprise/tests/unit/server/routes/test_api_keys.py +++ b/enterprise/tests/unit/server/routes/test_api_keys.py @@ -6,6 +6,7 @@ import httpx import pytest from fastapi import HTTPException from server.routes.api_keys import ( + check_byor_permitted, delete_byor_key_from_litellm, get_llm_api_key_for_byor, ) @@ -182,7 +183,7 @@ class TestGetLlmApiKeyForByor: """Test the get_llm_api_key_for_byor endpoint.""" @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') @patch('server.routes.api_keys.store_byor_key_in_db') @patch('server.routes.api_keys.generate_byor_key') @patch('server.routes.api_keys.get_byor_key_from_db') @@ -209,7 +210,7 @@ class TestGetLlmApiKeyForByor: mock_store_key.assert_called_once_with(user_id, new_key) @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') @patch('storage.lite_llm_manager.LiteLlmManager.verify_key') @patch('server.routes.api_keys.get_byor_key_from_db') async def test_valid_key_in_database_returns_key( @@ -233,7 +234,7 @@ class TestGetLlmApiKeyForByor: mock_verify_key.assert_called_once_with(existing_key, user_id) @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') @patch('server.routes.api_keys.store_byor_key_in_db') @patch('server.routes.api_keys.generate_byor_key') @patch('server.routes.api_keys.delete_byor_key_from_litellm') @@ -273,7 +274,7 @@ class TestGetLlmApiKeyForByor: mock_store_key.assert_called_once_with(user_id, new_key) @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') @patch('server.routes.api_keys.store_byor_key_in_db') @patch('server.routes.api_keys.generate_byor_key') @patch('server.routes.api_keys.delete_byor_key_from_litellm') @@ -311,7 +312,7 @@ class TestGetLlmApiKeyForByor: mock_store_key.assert_called_once_with(user_id, new_key) @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') @patch('server.routes.api_keys.generate_byor_key') @patch('server.routes.api_keys.get_byor_key_from_db') async def test_key_generation_failure_raises_exception( @@ -332,7 +333,7 @@ class TestGetLlmApiKeyForByor: assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') @patch('server.routes.api_keys.get_byor_key_from_db') async def test_database_error_raises_exception( self, mock_get_key, mock_check_enabled @@ -351,7 +352,7 @@ class TestGetLlmApiKeyForByor: assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail @pytest.mark.asyncio - @patch('server.routes.api_keys.check_byor_export_enabled') + @patch('storage.org_service.OrgService.check_byor_export_enabled') async def test_byor_export_disabled_returns_402(self, mock_check_enabled): """Test that when BYOR export is disabled, 402 is returned.""" # Arrange @@ -460,3 +461,52 @@ class TestDeleteByorKeyFromLitellm: # Assert assert result is False + + +class TestCheckByorPermitted: + """Test the check_byor_permitted endpoint.""" + + @pytest.mark.asyncio + @patch('storage.org_service.OrgService.check_byor_export_enabled') + async def test_permitted_when_enabled(self, mock_check_enabled): + """Test that permitted=True is returned when BYOR export is enabled.""" + # Arrange + user_id = 'user-123' + mock_check_enabled.return_value = True + + # Act + result = await check_byor_permitted(user_id=user_id) + + # Assert + assert result == {'permitted': True} + mock_check_enabled.assert_called_once_with(user_id) + + @pytest.mark.asyncio + @patch('storage.org_service.OrgService.check_byor_export_enabled') + async def test_not_permitted_when_disabled(self, mock_check_enabled): + """Test that permitted=False is returned when BYOR export is disabled.""" + # Arrange + user_id = 'user-123' + mock_check_enabled.return_value = False + + # Act + result = await check_byor_permitted(user_id=user_id) + + # Assert + assert result == {'permitted': False} + mock_check_enabled.assert_called_once_with(user_id) + + @pytest.mark.asyncio + @patch('storage.org_service.OrgService.check_byor_export_enabled') + async def test_error_raises_500(self, mock_check_enabled): + """Test that an exception raises 500 error.""" + # Arrange + user_id = 'user-123' + mock_check_enabled.side_effect = Exception('Database error') + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await check_byor_permitted(user_id=user_id) + + assert exc_info.value.status_code == 500 + assert 'Failed to check BYOR export permission' in exc_info.value.detail diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index c7259e9a05..ca54cd788b 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -317,6 +317,7 @@ async def test_success_callback_success(): patch( 'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget' ) as mock_update_budget, + patch('server.routes.billing.OrgStore.update_org') as mock_update_org, ): mock_db_session = MagicMock() mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session @@ -340,6 +341,12 @@ async def test_success_callback_success(): 125.0, # 100 + (25.00 from Stripe) ) + # Verify BYOR export is enabled for the org + mock_update_org.assert_called_once_with( + 'mock_org_id', + {'byor_export_enabled': True}, + ) + # Verify database updates assert mock_billing_session.status == 'completed' assert mock_billing_session.price == 25.0 diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index 388df70b27..fdc8b598bb 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -1657,3 +1657,147 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker): assert result.contact_name == 'Jane Doe' assert result.conversation_expiration == 60 assert result.enable_proactive_conversation_starters is False + + +@pytest.mark.asyncio +async def test_check_byor_export_enabled_returns_true_when_enabled(): + """ + GIVEN: User has current_org with byor_export_enabled=True + WHEN: check_byor_export_enabled is called + THEN: Returns True + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + + mock_user = MagicMock() + mock_user.current_org_id = org_id + + mock_org = MagicMock() + mock_org.byor_export_enabled = True + + with ( + patch( + 'storage.org_service.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + return_value=mock_org, + ), + ): + # Act + result = await OrgService.check_byor_export_enabled(user_id) + + # Assert + assert result is True + + +@pytest.mark.asyncio +async def test_check_byor_export_enabled_returns_false_when_disabled(): + """ + GIVEN: User has current_org with byor_export_enabled=False + WHEN: check_byor_export_enabled is called + THEN: Returns False + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + + mock_user = MagicMock() + mock_user.current_org_id = org_id + + mock_org = MagicMock() + mock_org.byor_export_enabled = False + + with ( + patch( + 'storage.org_service.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + return_value=mock_org, + ), + ): + # Act + result = await OrgService.check_byor_export_enabled(user_id) + + # Assert + assert result is False + + +@pytest.mark.asyncio +async def test_check_byor_export_enabled_returns_false_when_user_not_found(): + """ + GIVEN: User does not exist + WHEN: check_byor_export_enabled is called + THEN: Returns False + """ + # Arrange + user_id = 'nonexistent-user' + + with patch( + 'storage.org_service.UserStore.get_user_by_id_async', + AsyncMock(return_value=None), + ): + # Act + result = await OrgService.check_byor_export_enabled(user_id) + + # Assert + assert result is False + + +@pytest.mark.asyncio +async def test_check_byor_export_enabled_returns_false_when_no_current_org(): + """ + GIVEN: User exists but has no current_org_id + WHEN: check_byor_export_enabled is called + THEN: Returns False + """ + # Arrange + user_id = 'test-user-123' + + mock_user = MagicMock() + mock_user.current_org_id = None + + with patch( + 'storage.org_service.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ): + # Act + result = await OrgService.check_byor_export_enabled(user_id) + + # Assert + assert result is False + + +@pytest.mark.asyncio +async def test_check_byor_export_enabled_returns_false_when_org_not_found(): + """ + GIVEN: User has current_org_id but org does not exist + WHEN: check_byor_export_enabled is called + THEN: Returns False + """ + # Arrange + user_id = 'test-user-123' + org_id = uuid.uuid4() + + mock_user = MagicMock() + mock_user.current_org_id = org_id + + with ( + patch( + 'storage.org_service.UserStore.get_user_by_id_async', + AsyncMock(return_value=mock_user), + ), + patch( + 'storage.org_service.OrgStore.get_org_by_id', + return_value=None, + ), + ): + # Act + result = await OrgService.check_byor_export_enabled(user_id) + + # Assert + assert result is False