From 8927ac2230666a47814980d28690641da74636e6 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Fri, 20 Feb 2026 01:34:53 +0700 Subject: [PATCH] fix(backend): organization members now see correct shared credit balance (#12942) --- enterprise/server/routes/billing.py | 10 +- enterprise/storage/lite_llm_manager.py | 76 +++++- enterprise/storage/org_service.py | 5 +- enterprise/tests/unit/test_billing.py | 8 +- .../tests/unit/test_lite_llm_manager.py | 216 +++++++++++++++--- enterprise/tests/unit/test_org_service.py | 2 +- 6 files changed, 264 insertions(+), 53 deletions(-) diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index 781d542189..fccbdd3f1b 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -93,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse user_team_info = await LiteLlmManager.get_user_team_info( user_id, str(user.current_org_id) ) - # Update to use calculate_credits - spend = user_team_info.get('spend', 0) - max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0) + max_budget, spend = LiteLlmManager.get_budget_from_team_info( + user_team_info, user_id, str(user.current_org_id) + ) credits = max(max_budget - spend, 0) return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits))) @@ -249,8 +249,8 @@ async def success_callback(session_id: str, request: Request): ) amount_subtotal = stripe_session.amount_subtotal or 0 add_credits = amount_subtotal / 100 - max_budget = (user_team_info.get('litellm_budget_table') or {}).get( - 'max_budget', 0 + max_budget, _ = LiteLlmManager.get_budget_from_team_info( + user_team_info, billing_session.user_id, str(user.current_org_id) ) org = session.query(Org).filter(Org.id == user.current_org_id).first() diff --git a/enterprise/storage/lite_llm_manager.py b/enterprise/storage/lite_llm_manager.py index e636e39759..8cf8b4e998 100644 --- a/enterprise/storage/lite_llm_manager.py +++ b/enterprise/storage/lite_llm_manager.py @@ -43,6 +43,34 @@ def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str: class LiteLlmManager: """Manage LiteLLM interactions.""" + @staticmethod + def get_budget_from_team_info( + user_team_info: dict | None, user_id: str, org_id: str + ) -> tuple[float, float]: + """Extract max_budget and spend from user team info. + + For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget. + For team orgs, uses max_budget_in_team (populated by get_user_team_info). + + Args: + user_team_info: The response from get_user_team_info + user_id: The user's ID + org_id: The organization's ID + + Returns: + Tuple of (max_budget, spend) + """ + if not user_team_info: + return 0, 0 + spend = user_team_info.get('spend', 0) + if user_id == org_id: + max_budget = (user_team_info.get('litellm_budget_table') or {}).get( + 'max_budget', 0 + ) + else: + max_budget = user_team_info.get('max_budget_in_team') or 0 + return max_budget, spend + @staticmethod async def create_entries( org_id: str, @@ -71,8 +99,34 @@ class LiteLlmManager: 'x-goog-api-key': LITE_LLM_API_KEY, } ) as client: - # New users start with $0 budget - they must purchase credits - await LiteLlmManager._create_team(client, keycloak_user_id, org_id, 0) + # Check if team already exists and get its budget + # New users joining existing orgs should inherit the team's budget + team_budget = 0.0 + try: + existing_team = await LiteLlmManager._get_team(client, org_id) + if existing_team: + team_info = existing_team.get('team_info', {}) + team_budget = team_info.get('max_budget', 0.0) or 0.0 + logger.info( + 'LiteLlmManager:create_entries:existing_team_budget', + extra={ + 'org_id': org_id, + 'user_id': keycloak_user_id, + 'team_budget': team_budget, + }, + ) + except httpx.HTTPStatusError as e: + # Team doesn't exist yet (404) - this is expected for first user + if e.response.status_code != 404: + raise + logger.info( + 'LiteLlmManager:create_entries:no_existing_team', + extra={'org_id': org_id, 'user_id': keycloak_user_id}, + ) + + await LiteLlmManager._create_team( + client, keycloak_user_id, org_id, team_budget + ) if create_user: await LiteLlmManager._create_user( @@ -80,7 +134,7 @@ class LiteLlmManager: ) await LiteLlmManager._add_user_to_team( - client, keycloak_user_id, org_id, 0 + client, keycloak_user_id, org_id, team_budget ) key = await LiteLlmManager._generate_key( @@ -892,21 +946,31 @@ class LiteLlmManager: if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None: logger.warning('LiteLLM API configuration not found') return None - team_info = await LiteLlmManager._get_team(client, team_id) - if not team_info: + team_response = await LiteLlmManager._get_team(client, team_id) + if not team_response: return None # Filter team_memberships based on team_id and keycloak_user_id user_membership = next( ( membership - for membership in team_info.get('team_memberships', []) + for membership in team_response.get('team_memberships', []) if membership.get('user_id') == keycloak_user_id and membership.get('team_id') == team_id ), None, ) + if not user_membership: + return None + + # For team orgs (user_id != team_id), include team-level budget info + # The team's max_budget and spend are shared across all members + if keycloak_user_id != team_id: + team_info = team_response.get('team_info', {}) + user_membership['max_budget_in_team'] = team_info.get('max_budget') + user_membership['spend'] = team_info.get('spend', 0) + return user_membership @staticmethod diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index 144d636a83..00fdf4443e 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -656,10 +656,9 @@ class OrgService: ) return None - max_budget = (user_team_info.get('litellm_budget_table') or {}).get( - 'max_budget', 0 + max_budget, spend = LiteLlmManager.get_budget_from_team_info( + user_team_info, user_id, str(org_id) ) - spend = user_team_info.get('spend', 0) credits = max(max_budget - spend, 0) logger.debug( diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index fd28f4b644..7350b851c5 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -101,7 +101,7 @@ async def test_get_credits_success(): json={ 'user_info': { 'spend': 25.50, - 'litellm_budget_table': {'max_budget': 100.00}, + 'max_budget_in_team': 100.00, } }, request=MagicMock(), @@ -121,7 +121,7 @@ async def test_get_credits_success(): 'storage.lite_llm_manager.LiteLlmManager.get_user_team_info', return_value={ 'spend': 25.50, - 'litellm_budget_table': {'max_budget': 100.00}, + 'max_budget_in_team': 100.00, }, ), ): @@ -313,7 +313,7 @@ async def test_success_callback_success(): 'storage.lite_llm_manager.LiteLlmManager.get_user_team_info', return_value={ 'spend': 25.50, - 'litellm_budget_table': {'max_budget': 100.00}, + 'max_budget_in_team': 100.00, }, ), patch( @@ -430,7 +430,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(): 'storage.lite_llm_manager.LiteLlmManager.get_user_team_info', return_value={ 'spend': 0, - 'litellm_budget_table': {'max_budget': 0}, + 'max_budget_in_team': 0, }, ), patch( diff --git a/enterprise/tests/unit/test_lite_llm_manager.py b/enterprise/tests/unit/test_lite_llm_manager.py index c89a89d6ba..cac0b37e23 100644 --- a/enterprise/tests/unit/test_lite_llm_manager.py +++ b/enterprise/tests/unit/test_lite_llm_manager.py @@ -142,44 +142,192 @@ class TestLiteLlmManager: @pytest.mark.asyncio async def test_create_entries_cloud_deployment(self, mock_settings, mock_response): """Test create_entries in cloud deployment mode.""" - with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}): - with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'): - with patch( - 'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com' - ): - with patch( - 'storage.lite_llm_manager.TokenManager' - ) as mock_token_manager: - mock_token_manager.return_value.get_user_info_from_user_id = ( - AsyncMock(return_value={'email': 'test@example.com'}) - ) + mock_404_response = MagicMock() + mock_404_response.status_code = 404 + mock_404_response.is_success = False - with patch('httpx.AsyncClient') as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = ( - mock_client - ) - mock_client.post.return_value = mock_response + mock_token_manager = MagicMock() + mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( + return_value={'email': 'test@example.com'} + ) - result = await LiteLlmManager.create_entries( - 'test-org-id', - 'test-user-id', - mock_settings, - create_user=False, - ) + mock_client = AsyncMock() + mock_client.get.return_value = mock_404_response + mock_client.get.return_value.raise_for_status.side_effect = ( + httpx.HTTPStatusError( + message='Not Found', request=MagicMock(), response=mock_404_response + ) + ) + mock_client.post.return_value = mock_response - assert result is not None - assert result.agent == 'CodeActAgent' - assert result.llm_model == get_default_litellm_model() - assert ( - result.llm_api_key.get_secret_value() == 'test-api-key' - ) - assert result.llm_base_url == 'http://test.com' + mock_client_class = MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client - # Verify API calls were made - assert ( - mock_client.post.call_count == 3 - ) # create_team, create_user, add_user_to_team, generate_key + with ( + patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}), + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'), + patch('storage.lite_llm_manager.TokenManager', mock_token_manager), + patch('httpx.AsyncClient', mock_client_class), + ): + result = await LiteLlmManager.create_entries( + 'test-org-id', 'test-user-id', mock_settings, create_user=False + ) + + assert result is not None + assert result.agent == 'CodeActAgent' + assert result.llm_model == get_default_litellm_model() + assert result.llm_api_key.get_secret_value() == 'test-api-key' + assert result.llm_base_url == 'http://test.com' + + # Verify API calls were made (get_team + 3 posts) + assert mock_client.get.call_count == 1 # get_team + assert ( + mock_client.post.call_count == 3 + ) # create_team, add_user_to_team, generate_key + + @pytest.mark.asyncio + async def test_create_entries_inherits_existing_team_budget( + self, mock_settings, mock_response + ): + """Test that create_entries inherits budget from existing team.""" + mock_team_response = MagicMock() + mock_team_response.is_success = True + mock_team_response.status_code = 200 + mock_team_response.json.return_value = { + 'team_info': {'max_budget': 30.0, 'spend': 5.0}, + 'team_memberships': [], + } + mock_team_response.raise_for_status = MagicMock() + + mock_token_manager = MagicMock() + mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( + return_value={'email': 'test@example.com'} + ) + + mock_client = AsyncMock() + mock_client.get.return_value = mock_team_response + mock_client.post.return_value = mock_response + + mock_client_class = MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + with ( + patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}), + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'), + patch('storage.lite_llm_manager.TokenManager', mock_token_manager), + patch('httpx.AsyncClient', mock_client_class), + ): + result = await LiteLlmManager.create_entries( + 'test-org-id', 'test-user-id', mock_settings, create_user=False + ) + + assert result is not None + + # Verify _get_team was called first + mock_client.get.assert_called_once() + get_call_url = mock_client.get.call_args[0][0] + assert 'team/info' in get_call_url + assert 'test-org-id' in get_call_url + + # Verify _create_team was called with inherited budget (30.0) + create_team_call = mock_client.post.call_args_list[0] + assert 'team/new' in create_team_call[0][0] + assert create_team_call[1]['json']['max_budget'] == 30.0 + + # Verify _add_user_to_team was called with inherited budget (30.0) + add_user_call = mock_client.post.call_args_list[1] + assert 'team/member_add' in add_user_call[0][0] + assert add_user_call[1]['json']['max_budget_in_team'] == 30.0 + + @pytest.mark.asyncio + async def test_create_entries_new_org_uses_zero_budget( + self, mock_settings, mock_response + ): + """Test that create_entries uses budget=0 for new org (team doesn't exist).""" + mock_404_response = MagicMock() + mock_404_response.status_code = 404 + mock_404_response.is_success = False + + mock_token_manager = MagicMock() + mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( + return_value={'email': 'test@example.com'} + ) + + mock_client = AsyncMock() + mock_client.get.return_value = mock_404_response + mock_client.get.return_value.raise_for_status.side_effect = ( + httpx.HTTPStatusError( + message='Not Found', request=MagicMock(), response=mock_404_response + ) + ) + mock_client.post.return_value = mock_response + + mock_client_class = MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + with ( + patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}), + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'), + patch('storage.lite_llm_manager.TokenManager', mock_token_manager), + patch('httpx.AsyncClient', mock_client_class), + ): + result = await LiteLlmManager.create_entries( + 'test-org-id', 'test-user-id', mock_settings, create_user=False + ) + + assert result is not None + + # Verify _create_team was called with budget=0 + create_team_call = mock_client.post.call_args_list[0] + assert 'team/new' in create_team_call[0][0] + assert create_team_call[1]['json']['max_budget'] == 0.0 + + # Verify _add_user_to_team was called with budget=0 + add_user_call = mock_client.post.call_args_list[1] + assert 'team/member_add' in add_user_call[0][0] + assert add_user_call[1]['json']['max_budget_in_team'] == 0.0 + + @pytest.mark.asyncio + async def test_create_entries_propagates_non_404_errors(self, mock_settings): + """Test that create_entries propagates non-404 errors from _get_team.""" + mock_500_response = MagicMock() + mock_500_response.status_code = 500 + mock_500_response.is_success = False + + mock_token_manager = MagicMock() + mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock( + return_value={'email': 'test@example.com'} + ) + + mock_client = AsyncMock() + mock_client.get.return_value = mock_500_response + mock_client.get.return_value.raise_for_status.side_effect = ( + httpx.HTTPStatusError( + message='Internal Server Error', + request=MagicMock(), + response=mock_500_response, + ) + ) + + mock_client_class = MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + with ( + patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}), + patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'), + patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'), + patch('storage.lite_llm_manager.TokenManager', mock_token_manager), + patch('httpx.AsyncClient', mock_client_class), + ): + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await LiteLlmManager.create_entries( + 'test-org-id', 'test-user-id', mock_settings, create_user=False + ) + + assert exc_info.value.response.status_code == 500 @pytest.mark.asyncio async def test_migrate_entries_missing_config(self, mock_user_settings): diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index 31f96adfc8..47f7cd109a 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -482,7 +482,7 @@ async def test_get_org_credits_success(mock_litellm_api): spend = 25.0 mock_team_info = { - 'litellm_budget_table': {'max_budget': max_budget}, + 'max_budget_in_team': max_budget, 'spend': spend, }