mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -6,30 +6,13 @@ from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_store import OrgStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
async def check_byor_export_enabled(user_id: str) -> bool:
|
||||
"""Check if BYOR export is enabled for the user's current org.
|
||||
|
||||
Returns True if the user's current org has byor_export_enabled set to True.
|
||||
Returns False if the user is not found, has no current org, or the flag is False.
|
||||
"""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
org = OrgStore.get_org_by_id(user.current_org_id)
|
||||
if not org:
|
||||
return False
|
||||
|
||||
return org.byor_export_enabled
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
@@ -173,7 +156,7 @@ class ByorPermittedResponse(BaseModel):
|
||||
async def check_byor_permitted(user_id: str = Depends(get_user_id)):
|
||||
"""Check if BYOR key export is permitted for the user's current org."""
|
||||
try:
|
||||
permitted = await check_byor_export_enabled(user_id)
|
||||
permitted = await OrgService.check_byor_export_enabled(user_id)
|
||||
return {'permitted': permitted}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@@ -295,7 +278,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
"""
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await check_byor_export_enabled(user_id):
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
@@ -364,7 +347,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await check_byor_export_enabled(user_id):
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
|
||||
@@ -17,6 +17,7 @@ from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_store import OrgStore
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -259,6 +260,9 @@ async def success_callback(session_id: str, request: Request):
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Enable BYOR export for the org now that they've purchased credits
|
||||
OrgStore.update_org(user.current_org_id, {'byor_export_enabled': True})
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
|
||||
@@ -842,3 +842,26 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def check_byor_export_enabled(user_id: str) -> bool:
|
||||
"""Check if BYOR export is enabled for the user's current org.
|
||||
|
||||
Returns True if the user's current org has byor_export_enabled set to True.
|
||||
Returns False if the user is not found, has no current org, or the flag is False.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
|
||||
Returns:
|
||||
bool: True if BYOR export is enabled, False otherwise
|
||||
"""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
org = OrgStore.get_org_by_id(user.current_org_id)
|
||||
if not org:
|
||||
return False
|
||||
|
||||
return org.byor_export_enabled
|
||||
|
||||
@@ -6,6 +6,7 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
@@ -182,7 +183,7 @@ class TestGetLlmApiKeyForByor:
|
||||
"""Test the get_llm_api_key_for_byor endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
@@ -209,7 +210,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
@@ -233,7 +234,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_verify_key.assert_called_once_with(existing_key, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -273,7 +274,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -311,7 +312,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_key_generation_failure_raises_exception(
|
||||
@@ -332,7 +333,7 @@ class TestGetLlmApiKeyForByor:
|
||||
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_database_error_raises_exception(
|
||||
self, mock_get_key, mock_check_enabled
|
||||
@@ -351,7 +352,7 @@ class TestGetLlmApiKeyForByor:
|
||||
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.check_byor_export_enabled')
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
|
||||
"""Test that when BYOR export is disabled, 402 is returned."""
|
||||
# Arrange
|
||||
@@ -460,3 +461,52 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCheckByorPermitted:
|
||||
"""Test the check_byor_permitted endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_permitted_when_enabled(self, mock_check_enabled):
|
||||
"""Test that permitted=True is returned when BYOR export is enabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'permitted': True}
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_not_permitted_when_disabled(self, mock_check_enabled):
|
||||
"""Test that permitted=False is returned when BYOR export is disabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'permitted': False}
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_error_raises_500(self, mock_check_enabled):
|
||||
"""Test that an exception raises 500 error."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.side_effect = Exception('Database error')
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_byor_permitted(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
@@ -317,6 +317,7 @@ async def test_success_callback_success():
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
patch('server.routes.billing.OrgStore.update_org') as mock_update_org,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
@@ -340,6 +341,12 @@ async def test_success_callback_success():
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
)
|
||||
|
||||
# Verify BYOR export is enabled for the org
|
||||
mock_update_org.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
{'byor_export_enabled': True},
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
|
||||
@@ -1657,3 +1657,147 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker):
|
||||
assert result.contact_name == 'Jane Doe'
|
||||
assert result.conversation_expiration == 60
|
||||
assert result.enable_proactive_conversation_starters is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_true_when_enabled():
|
||||
"""
|
||||
GIVEN: User has current_org with byor_export_enabled=True
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns True
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.byor_export_enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_disabled():
|
||||
"""
|
||||
GIVEN: User has current_org with byor_export_enabled=False
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.byor_export_enabled = False
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_user_not_found():
|
||||
"""
|
||||
GIVEN: User does not exist
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'nonexistent-user'
|
||||
|
||||
with patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_no_current_org():
|
||||
"""
|
||||
GIVEN: User exists but has no current_org_id
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = None
|
||||
|
||||
with patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_org_not_found():
|
||||
"""
|
||||
GIVEN: User has current_org_id but org does not exist
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
Reference in New Issue
Block a user