mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
fix(backend): invalid api key (#12217)
This commit is contained in:
@@ -38,6 +38,8 @@ LITE_LLM_API_URL = os.environ.get(
|
|||||||
)
|
)
|
||||||
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
|
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
|
||||||
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
|
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
|
||||||
|
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||||
|
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
|
||||||
SUBSCRIPTION_PRICE_DATA = {
|
SUBSCRIPTION_PRICE_DATA = {
|
||||||
'MONTHLY_SUBSCRIPTION': {
|
'MONTHLY_SUBSCRIPTION': {
|
||||||
'unit_amount': 2000,
|
'unit_amount': 2000,
|
||||||
|
|||||||
@@ -4,7 +4,11 @@ import httpx
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
from server.config import get_config
|
from server.config import get_config
|
||||||
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
|
from server.constants import (
|
||||||
|
BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||||
|
LITE_LLM_API_KEY,
|
||||||
|
LITE_LLM_API_URL,
|
||||||
|
)
|
||||||
from storage.api_key_store import ApiKeyStore
|
from storage.api_key_store import ApiKeyStore
|
||||||
from storage.database import session_maker
|
from storage.database import session_maker
|
||||||
from storage.saas_settings_store import SaasSettingsStore
|
from storage.saas_settings_store import SaasSettingsStore
|
||||||
@@ -112,6 +116,70 @@ async def generate_byor_key(user_id: str) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
|
||||||
|
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
byor_key: The BYOR key to verify
|
||||||
|
user_id: The user ID for logging purposes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||||
|
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||||
|
"""
|
||||||
|
if not (LITE_LLM_API_URL and byor_key):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
verify=httpx_verify_option(),
|
||||||
|
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||||
|
) as client:
|
||||||
|
# Make a lightweight request to verify the key
|
||||||
|
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||||
|
response = await client.get(
|
||||||
|
f'{LITE_LLM_API_URL}/v1/models',
|
||||||
|
headers={
|
||||||
|
'Authorization': f'Bearer {byor_key}',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only 200 status code indicates valid key
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.debug(
|
||||||
|
'BYOR key verification successful',
|
||||||
|
extra={'user_id': user_id},
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||||
|
# This includes authentication errors and server errors
|
||||||
|
logger.warning(
|
||||||
|
'BYOR key verification failed - treating as invalid',
|
||||||
|
extra={
|
||||||
|
'user_id': user_id,
|
||||||
|
'status_code': response.status_code,
|
||||||
|
'key_prefix': byor_key[:10] + '...'
|
||||||
|
if len(byor_key) > 10
|
||||||
|
else byor_key,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except (httpx.TimeoutException, Exception) as e:
|
||||||
|
# Any exception (timeout, network error, etc.) means we can't verify
|
||||||
|
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||||
|
logger.warning(
|
||||||
|
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||||
|
extra={
|
||||||
|
'user_id': user_id,
|
||||||
|
'error': str(e),
|
||||||
|
'error_type': type(e).__name__,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||||
@@ -278,18 +346,44 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
|||||||
|
|
||||||
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
|
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
|
||||||
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||||
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
|
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||||
|
|
||||||
|
This endpoint validates that the key exists in LiteLLM before returning it.
|
||||||
|
If validation fails, it automatically generates a new key to ensure users
|
||||||
|
always receive a working key.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Check if the BYOR key exists in the database
|
# Check if the BYOR key exists in the database
|
||||||
byor_key = await get_byor_key_from_db(user_id)
|
byor_key = await get_byor_key_from_db(user_id)
|
||||||
if byor_key:
|
if byor_key:
|
||||||
return {'key': byor_key}
|
# Validate that the key is actually registered in LiteLLM
|
||||||
|
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
if is_valid:
|
||||||
|
return {'key': byor_key}
|
||||||
|
else:
|
||||||
|
# Key exists in DB but is invalid in LiteLLM - regenerate it
|
||||||
|
logger.warning(
|
||||||
|
'BYOR key found in database but invalid in LiteLLM - regenerating',
|
||||||
|
extra={
|
||||||
|
'user_id': user_id,
|
||||||
|
'key_prefix': byor_key[:10] + '...'
|
||||||
|
if len(byor_key) > 10
|
||||||
|
else byor_key,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Delete the invalid key from LiteLLM (best effort, don't fail if it doesn't exist)
|
||||||
|
await delete_byor_key_from_litellm(user_id, byor_key)
|
||||||
|
# Fall through to generate a new key
|
||||||
|
|
||||||
# If not, generate a new key for BYOR
|
# Generate a new key for BYOR (either no key exists or validation failed)
|
||||||
key = await generate_byor_key(user_id)
|
key = await generate_byor_key(user_id)
|
||||||
if key:
|
if key:
|
||||||
# Store the key in the database
|
# Store the key in the database
|
||||||
await store_byor_key_in_db(user_id, key)
|
await store_byor_key_in_db(user_id, key)
|
||||||
|
logger.info(
|
||||||
|
'Successfully generated and stored new BYOR key',
|
||||||
|
extra={'user_id': user_id},
|
||||||
|
)
|
||||||
return {'key': key}
|
return {'key': key}
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -301,6 +395,9 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
|||||||
detail='Failed to generate new BYOR LLM API key',
|
detail='Failed to generate new BYOR LLM API key',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
# Re-raise HTTP exceptions as-is
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception('Error retrieving BYOR LLM API key', extra={'error': str(e)})
|
logger.exception('Error retrieving BYOR LLM API key', extra={'error': str(e)})
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
330
enterprise/tests/unit/server/routes/test_api_keys.py
Normal file
330
enterprise/tests/unit/server/routes/test_api_keys.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
"""Unit tests for API keys routes, focusing on BYOR key validation and retrieval."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from server.routes.api_keys import (
|
||||||
|
get_llm_api_key_for_byor,
|
||||||
|
verify_byor_key_in_litellm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerifyByorKeyInLitellm:
|
||||||
|
"""Test the verify_byor_key_in_litellm function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||||
|
async def test_verify_valid_key_returns_true(self, mock_client_class):
|
||||||
|
"""Test that a valid key (200 response) returns True."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-valid-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.is_success = True
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__.return_value = mock_client
|
||||||
|
mock_client.__aexit__.return_value = None
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is True
|
||||||
|
mock_client.get.assert_called_once_with(
|
||||||
|
'https://litellm.example.com/v1/models',
|
||||||
|
headers={'Authorization': f'Bearer {byor_key}'},
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||||
|
async def test_verify_invalid_key_401_returns_false(self, mock_client_class):
|
||||||
|
"""Test that an invalid key (401 response) returns False."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-invalid-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 401
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__.return_value = mock_client
|
||||||
|
mock_client.__aexit__.return_value = None
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||||
|
async def test_verify_invalid_key_403_returns_false(self, mock_client_class):
|
||||||
|
"""Test that an invalid key (403 response) returns False."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-forbidden-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 403
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__.return_value = mock_client
|
||||||
|
mock_client.__aexit__.return_value = None
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||||
|
async def test_verify_server_error_returns_false(self, mock_client_class):
|
||||||
|
"""Test that a server error (500) returns False to ensure key validity."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 500
|
||||||
|
mock_response.is_success = False
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__.return_value = mock_client
|
||||||
|
mock_client.__aexit__.return_value = None
|
||||||
|
mock_client.get.return_value = mock_response
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||||
|
async def test_verify_timeout_returns_false(self, mock_client_class):
|
||||||
|
"""Test that a timeout returns False to ensure key validity."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__.return_value = mock_client
|
||||||
|
mock_client.__aexit__.return_value = None
|
||||||
|
mock_client.get.side_effect = httpx.TimeoutException('Request timed out')
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||||
|
async def test_verify_network_error_returns_false(self, mock_client_class):
|
||||||
|
"""Test that a network error returns False to ensure key validity."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.__aenter__.return_value = mock_client
|
||||||
|
mock_client.__aexit__.return_value = None
|
||||||
|
mock_client.get.side_effect = httpx.NetworkError('Network error')
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', None)
|
||||||
|
async def test_verify_missing_api_url_returns_false(self):
|
||||||
|
"""Test that missing LITE_LLM_API_URL returns False."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = 'sk-key-123'
|
||||||
|
user_id = 'user-123'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||||
|
async def test_verify_empty_key_returns_false(self):
|
||||||
|
"""Test that empty key returns False."""
|
||||||
|
# Arrange
|
||||||
|
byor_key = ''
|
||||||
|
user_id = 'user-123'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLlmApiKeyForByor:
|
||||||
|
"""Test the get_llm_api_key_for_byor endpoint."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||||
|
@patch('server.routes.api_keys.generate_byor_key')
|
||||||
|
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||||
|
async def test_no_key_in_database_generates_new(
|
||||||
|
self, mock_get_key, mock_generate_key, mock_store_key
|
||||||
|
):
|
||||||
|
"""Test that when no key exists in database, a new one is generated."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'user-123'
|
||||||
|
new_key = 'sk-new-generated-key'
|
||||||
|
mock_get_key.return_value = None
|
||||||
|
mock_generate_key.return_value = new_key
|
||||||
|
mock_store_key.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == {'key': new_key}
|
||||||
|
mock_get_key.assert_called_once_with(user_id)
|
||||||
|
mock_generate_key.assert_called_once_with(user_id)
|
||||||
|
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||||
|
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||||
|
async def test_valid_key_in_database_returns_key(
|
||||||
|
self, mock_get_key, mock_verify_key
|
||||||
|
):
|
||||||
|
"""Test that when a valid key exists in database, it is returned."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'user-123'
|
||||||
|
existing_key = 'sk-existing-valid-key'
|
||||||
|
mock_get_key.return_value = existing_key
|
||||||
|
mock_verify_key.return_value = True
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == {'key': existing_key}
|
||||||
|
mock_get_key.assert_called_once_with(user_id)
|
||||||
|
mock_verify_key.assert_called_once_with(existing_key, user_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||||
|
@patch('server.routes.api_keys.generate_byor_key')
|
||||||
|
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||||
|
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||||
|
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||||
|
async def test_invalid_key_in_database_regenerates(
|
||||||
|
self,
|
||||||
|
mock_get_key,
|
||||||
|
mock_verify_key,
|
||||||
|
mock_delete_key,
|
||||||
|
mock_generate_key,
|
||||||
|
mock_store_key,
|
||||||
|
):
|
||||||
|
"""Test that when an invalid key exists in database, it is regenerated."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'user-123'
|
||||||
|
invalid_key = 'sk-invalid-key'
|
||||||
|
new_key = 'sk-new-generated-key'
|
||||||
|
mock_get_key.return_value = invalid_key
|
||||||
|
mock_verify_key.return_value = False
|
||||||
|
mock_delete_key.return_value = True
|
||||||
|
mock_generate_key.return_value = new_key
|
||||||
|
mock_store_key.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == {'key': new_key}
|
||||||
|
mock_get_key.assert_called_once_with(user_id)
|
||||||
|
mock_verify_key.assert_called_once_with(invalid_key, user_id)
|
||||||
|
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||||
|
mock_generate_key.assert_called_once_with(user_id)
|
||||||
|
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||||
|
@patch('server.routes.api_keys.generate_byor_key')
|
||||||
|
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||||
|
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||||
|
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||||
|
async def test_invalid_key_deletion_failure_still_regenerates(
|
||||||
|
self,
|
||||||
|
mock_get_key,
|
||||||
|
mock_verify_key,
|
||||||
|
mock_delete_key,
|
||||||
|
mock_generate_key,
|
||||||
|
mock_store_key,
|
||||||
|
):
|
||||||
|
"""Test that even if deletion fails, regeneration still proceeds."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'user-123'
|
||||||
|
invalid_key = 'sk-invalid-key'
|
||||||
|
new_key = 'sk-new-generated-key'
|
||||||
|
mock_get_key.return_value = invalid_key
|
||||||
|
mock_verify_key.return_value = False
|
||||||
|
mock_delete_key.return_value = False # Deletion fails
|
||||||
|
mock_generate_key.return_value = new_key
|
||||||
|
mock_store_key.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result == {'key': new_key}
|
||||||
|
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||||
|
mock_generate_key.assert_called_once_with(user_id)
|
||||||
|
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.generate_byor_key')
|
||||||
|
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||||
|
async def test_key_generation_failure_raises_exception(
|
||||||
|
self, mock_get_key, mock_generate_key
|
||||||
|
):
|
||||||
|
"""Test that when key generation fails, an HTTPException is raised."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_get_key.return_value = None
|
||||||
|
mock_generate_key.return_value = None
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await get_llm_api_key_for_byor(user_id=user_id)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 500
|
||||||
|
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||||
|
async def test_database_error_raises_exception(self, mock_get_key):
|
||||||
|
"""Test that database errors are properly handled."""
|
||||||
|
# Arrange
|
||||||
|
user_id = 'user-123'
|
||||||
|
mock_get_key.side_effect = Exception('Database connection error')
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await get_llm_api_key_for_byor(user_id=user_id)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 500
|
||||||
|
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||||
Reference in New Issue
Block a user