mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
3 Commits
enhance/pa
...
fix-metric
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a8d6f1f1c | ||
|
|
7b54eb3ab8 | ||
|
|
c3cc4c71f9 |
@@ -428,3 +428,41 @@ class EventStream:
|
||||
break
|
||||
|
||||
return matching_events
|
||||
|
||||
def get_metrics(self):
    """Aggregate the LLM metrics attached to the events of this stream.

    Every event returned by ``self.get_events()`` whose ``llm_metrics``
    attribute is set contributes to the result. A fresh ``Metrics`` object is
    created with the model name of the first such event, and the metrics of
    each of those events are merged into it in stream order.

    Errors raised while reading events or merging are logged and not
    propagated; whatever had been aggregated at that point is returned.

    NOTE(review): merging every event assumes each ``llm_metrics`` holds only
    that event's own usage. If events carry cumulative snapshots instead,
    the result would over-count -- confirm against the code that sets
    ``llm_metrics``.

    Returns:
        Metrics: The merged metrics, or None when no event carries metrics
        or an error occurred before aggregation started.
    """
    # Kept as a function-level import, as in the original
    # (presumably to avoid an import cycle -- confirm).
    from openhands.llm.metrics import Metrics

    metrics = None
    try:
        # Collect everything before merging so that a failure while reading
        # events yields None rather than a partially merged object.
        per_event_metrics = [
            event.llm_metrics
            for event in self.get_events()
            if getattr(event, 'llm_metrics', None) is not None
        ]

        if per_event_metrics:
            metrics = Metrics(model_name=per_event_metrics[0].model_name)
            for event_metrics in per_event_metrics:
                metrics.merge(event_metrics)
    except Exception as e:
        logger.error(f'Error retrieving metrics from events: {e}')

    return metrics
|
||||
|
||||
@@ -7,6 +7,74 @@ from openhands.runtime.base import Runtime
|
||||
# Router for the per-conversation endpoints; routes registered on it use the
# '/api/conversations/{conversation_id}' prefix.
app = APIRouter(prefix='/api/conversations/{conversation_id}')
|
||||
|
||||
|
||||
@app.get('/metrics')
async def get_conversation_metrics(request: Request):
    """Return the accumulated cost and token usage of the conversation.

    Metrics are taken from ``conversation.get_metrics()`` first. When that
    yields nothing and the conversation's event stream exposes a
    ``get_metrics`` method, the event stream is used as a fallback. When
    neither source yields metrics, a payload of zeros is returned.

    Args:
        request (Request): The incoming FastAPI request object; the
            conversation is read from ``request.state.conversation``.

    Returns:
        JSONResponse: On success, status 200 with the keys
        ``accumulated_cost``, ``total_prompt_tokens``,
        ``total_completion_tokens`` and ``total_tokens``. Status 500 with an
        ``error`` key when the request carries no conversation or when any
        exception is raised while collecting the metrics.
    """
    try:
        # A missing attribute and an explicit None are both treated as
        # "no conversation".
        conversation = getattr(request.state, 'conversation', None)
        if conversation is None:
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={'error': 'No conversation found in request state'},
            )

        metrics = conversation.get_metrics()

        # Fall back to event stream metrics for backward compatibility.
        if not metrics and hasattr(conversation.event_stream, 'get_metrics'):
            metrics = conversation.event_stream.get_metrics()

        # Start from the "no metrics" payload and fill it in when available,
        # so the response shape is defined in exactly one place.
        accumulated_cost = 0.0
        total_prompt_tokens = 0
        total_completion_tokens = 0
        if metrics:
            for usage in metrics.token_usages:
                total_prompt_tokens += usage.prompt_tokens
                total_completion_tokens += usage.completion_tokens
            accumulated_cost = metrics.accumulated_cost

        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={
                'accumulated_cost': accumulated_cost,
                'total_prompt_tokens': total_prompt_tokens,
                'total_completion_tokens': total_completion_tokens,
                'total_tokens': total_prompt_tokens + total_completion_tokens,
            },
        )
    except Exception as e:
        logger.error(f'Error getting conversation metrics: {e}')
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                'error': f'Error getting conversation metrics: {e}',
            },
        )
|
||||
|
||||
|
||||
@app.get('/config')
|
||||
async def get_remote_runtime_config(request: Request):
|
||||
"""Retrieve the runtime configuration.
|
||||
|
||||
@@ -46,3 +46,24 @@ class Conversation:
|
||||
if self.event_stream:
|
||||
self.event_stream.close()
|
||||
asyncio.create_task(call_sync_from_async(self.runtime.close))
|
||||
|
||||
def get_metrics(self):
    """Return the metrics stored on the runtime's state, if any.

    Reads ``self.runtime.state.metrics``. Nothing is reconstructed from
    events here.

    Returns:
        Metrics: The ``metrics`` attribute of the runtime's state. None when
        the runtime has no ``state`` attribute, when that state is falsy, or
        when reading it raises (the error is logged, not propagated).
    """
    try:
        # One lookup covers both "attribute missing" and "state not set".
        state = getattr(self.runtime, 'state', None)
        if not state:
            return None
        return state.metrics
    except Exception as e:
        # Logger is imported only on the failure path, as in the original.
        from openhands.core.logger import openhands_logger as logger

        logger.error(f'Error retrieving metrics from runtime state: {e}')
        return None
|
||||
|
||||
129
tests/unit/events/test_stream.py
Normal file
129
tests/unit/events/test_stream.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.metrics import Metrics
|
||||
|
||||
|
||||
class TestEventStream:
    """Unit tests for ``EventStream.get_metrics``."""

    @staticmethod
    def _make_stream():
        """Build an EventStream backed by a mocked file store."""
        return EventStream(sid='test-stream-id', file_store=MagicMock())

    @staticmethod
    def _event_without_metrics():
        """Build a mock event whose ``llm_metrics`` is None."""
        event = MagicMock(spec=Event)
        event.llm_metrics = None
        return event

    @staticmethod
    def _event_with_usage(prompt_tokens, completion_tokens, response_id):
        """Build a mock event carrying 'gpt-4' metrics with one token usage."""
        metrics = Metrics(model_name='gpt-4')
        metrics.add_token_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id=response_id,
        )
        event = MagicMock(spec=Event)
        event.llm_metrics = metrics
        return event

    @staticmethod
    def _assert_two_usages_merged(result):
        """Check the aggregate of a (10, 20) usage and a (15, 25) usage."""
        assert result is not None
        assert result.model_name == 'gpt-4'
        total_prompt_tokens = sum(
            usage.prompt_tokens for usage in result.token_usages
        )
        total_completion_tokens = sum(
            usage.completion_tokens for usage in result.token_usages
        )
        assert total_prompt_tokens == 25  # 10 + 15
        assert total_completion_tokens == 45  # 20 + 25
        assert len(result.token_usages) == 2

    def test_get_metrics_empty_stream(self):
        """Test that get_metrics returns None for an empty stream."""
        stream = self._make_stream()
        assert stream.get_metrics() is None

    def test_get_metrics_no_metrics_in_events(self):
        """Test that get_metrics returns None when no events have metrics."""
        stream = self._make_stream()
        events = [self._event_without_metrics()]

        with patch.object(stream, 'get_events', return_value=events):
            assert stream.get_metrics() is None

    def test_get_metrics_with_metrics(self):
        """Test that get_metrics correctly aggregates metrics from events."""
        stream = self._make_stream()
        events = [
            self._event_with_usage(10, 20, 'resp1'),
            self._event_with_usage(15, 25, 'resp2'),
        ]

        with patch.object(stream, 'get_events', return_value=events):
            result = stream.get_metrics()

        self._assert_two_usages_merged(result)

    def test_get_metrics_with_exception(self):
        """Test that get_metrics handles exceptions gracefully."""
        stream = self._make_stream()

        with patch.object(
            stream, 'get_events', side_effect=Exception('Test exception')
        ):
            assert stream.get_metrics() is None

    def test_get_metrics_with_mixed_events(self):
        """Test that get_metrics correctly handles a mix of events with and without metrics."""
        stream = self._make_stream()
        events = [
            self._event_with_usage(10, 20, 'resp1'),
            self._event_without_metrics(),
            self._event_with_usage(15, 25, 'resp3'),
        ]

        with patch.object(stream, 'get_events', return_value=events):
            result = stream.get_metrics()

        self._assert_two_usages_merged(result)
|
||||
203
tests/unit/server/routes/test_conversation_metrics.py
Normal file
203
tests/unit/server/routes/test_conversation_metrics.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.server.routes.conversation import get_conversation_metrics
|
||||
|
||||
|
||||
@pytest.fixture
def mock_request():
    """Provide a mocked Request whose state holds a mocked conversation."""
    conversation = MagicMock()
    conversation.runtime = MagicMock()
    conversation.event_stream = MagicMock()
    # Explicit mock for the method the endpoint calls on the conversation.
    conversation.get_metrics = MagicMock()

    request = MagicMock(spec=Request)
    request.state.conversation = conversation
    return request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_success(mock_request):
    """Metrics returned by the conversation are summed and served directly."""
    import json

    conversation = mock_request.state.conversation

    # Two usages: prompt 100 + 200, completion 50 + 150.
    runtime_metrics = Metrics()
    runtime_metrics.token_usages = [
        TokenUsage(
            prompt_tokens=prompt,
            completion_tokens=completion,
            model='test-model',
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id=response_id,
        )
        for prompt, completion, response_id in (
            (100, 50, 'test-response-1'),
            (200, 150, 'test-response-2'),
        )
    ]
    runtime_metrics.accumulated_cost = 0.25
    conversation.get_metrics.return_value = runtime_metrics

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    content_dict = json.loads(response.body.decode('utf-8'))
    assert content_dict['accumulated_cost'] == 0.25
    assert content_dict['total_prompt_tokens'] == 300
    assert content_dict['total_completion_tokens'] == 200
    assert content_dict['total_tokens'] == 500

    # The conversation was consulted and the event stream fallback was not.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_fallback_to_event_stream(mock_request):
    """Event stream metrics are served when the conversation returns None."""
    import json

    conversation = mock_request.state.conversation

    stream_metrics = Metrics()
    stream_metrics.token_usages = [
        TokenUsage(
            prompt_tokens=100,
            completion_tokens=50,
            model='test-model',
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id='test-response-1',
        ),
    ]
    stream_metrics.accumulated_cost = 0.15

    # Nothing from the conversation itself, metrics from the event stream.
    conversation.get_metrics.return_value = None
    conversation.event_stream.get_metrics.return_value = stream_metrics

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    content_dict = json.loads(response.body.decode('utf-8'))
    assert content_dict['accumulated_cost'] == 0.15
    assert content_dict['total_prompt_tokens'] == 100
    assert content_dict['total_completion_tokens'] == 50
    assert content_dict['total_tokens'] == 150

    # Both sources were consulted exactly once.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_metrics(mock_request):
    """A payload of zeros is served when neither source has metrics."""
    import json

    conversation = mock_request.state.conversation
    conversation.get_metrics.return_value = None
    conversation.event_stream.get_metrics.return_value = None

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    content_dict = json.loads(response.body.decode('utf-8'))
    assert content_dict['accumulated_cost'] == 0.0
    assert content_dict['total_prompt_tokens'] == 0
    assert content_dict['total_completion_tokens'] == 0
    assert content_dict['total_tokens'] == 0

    # Both sources were consulted exactly once.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_conversation():
    """A request whose state has no conversation yields a 500 error."""
    import json

    request = MagicMock(spec=Request)
    request.state = MagicMock()
    # Deleting the attribute makes later access raise AttributeError instead
    # of auto-creating a child mock.
    delattr(request.state, 'conversation')

    response = await get_conversation_metrics(request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    content_dict = json.loads(response.body.decode('utf-8'))
    assert 'error' in content_dict
    assert 'No conversation found in request state' in content_dict['error']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_exception(mock_request):
    """An exception raised while fetching metrics yields a 500 error."""
    import json

    failure = Exception('Test exception')
    mock_request.state.conversation.get_metrics.side_effect = failure

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    content_dict = json.loads(response.body.decode('utf-8'))
    assert 'error' in content_dict
    assert 'Error getting conversation metrics' in content_dict['error']
|
||||
209
tests/unit/test_conversation_metrics_api.py
Normal file
209
tests/unit/test_conversation_metrics_api.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
from openhands.server.routes.conversation import get_conversation_metrics
|
||||
|
||||
|
||||
@pytest.fixture
def mock_metrics():
    """Provide Metrics with cost 0.25 and two 'gpt-4' token usages."""
    first_usage = TokenUsage(
        model='gpt-4',
        prompt_tokens=100,
        completion_tokens=50,
        cache_read_tokens=0,
        cache_write_tokens=0,
        response_id='resp1',
    )
    second_usage = TokenUsage(
        model='gpt-4',
        prompt_tokens=200,
        completion_tokens=75,
        cache_read_tokens=0,
        cache_write_tokens=0,
        response_id='resp2',
    )

    metrics = Metrics()
    metrics.accumulated_cost = 0.25
    metrics.token_usages = [first_usage, second_usage]
    return metrics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_success(mock_metrics):
    """Metrics returned by the conversation are summed and served directly."""
    import json

    # Conversation whose own get_metrics() supplies the fixture metrics.
    mock_conversation = MagicMock()
    mock_conversation.get_metrics = MagicMock(return_value=mock_metrics)
    mock_conversation.event_stream = MagicMock()

    mock_request = MagicMock()
    mock_request.state.conversation = mock_conversation

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    data = json.loads(response.body.decode('utf-8'))
    assert data['accumulated_cost'] == 0.25
    assert data['total_prompt_tokens'] == 300
    assert data['total_completion_tokens'] == 125
    assert data['total_tokens'] == 425

    # The conversation was consulted and the event stream fallback was not.
    mock_conversation.get_metrics.assert_called_once()
    mock_conversation.event_stream.get_metrics.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_fallback_to_event_stream(mock_metrics):
    """Event stream metrics are served when the conversation returns None."""
    import json

    # Nothing from the conversation itself, fixture metrics from the stream.
    mock_conversation = MagicMock()
    mock_conversation.get_metrics = MagicMock(return_value=None)
    mock_conversation.event_stream = MagicMock()
    mock_conversation.event_stream.get_metrics = MagicMock(return_value=mock_metrics)

    mock_request = MagicMock()
    mock_request.state.conversation = mock_conversation

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    data = json.loads(response.body.decode('utf-8'))
    assert data['accumulated_cost'] == 0.25
    assert data['total_prompt_tokens'] == 300
    assert data['total_completion_tokens'] == 125
    assert data['total_tokens'] == 425

    # Both sources were consulted exactly once.
    mock_conversation.get_metrics.assert_called_once()
    mock_conversation.event_stream.get_metrics.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_metrics():
    """A payload of zeros is served when neither source has metrics."""
    import json

    mock_conversation = MagicMock()
    mock_conversation.get_metrics = MagicMock(return_value=None)
    mock_conversation.event_stream = MagicMock()
    mock_conversation.event_stream.get_metrics = MagicMock(return_value=None)

    mock_request = MagicMock()
    mock_request.state.conversation = mock_conversation

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    data = json.loads(response.body.decode('utf-8'))
    assert data['accumulated_cost'] == 0.0
    assert data['total_prompt_tokens'] == 0
    assert data['total_completion_tokens'] == 0
    assert data['total_tokens'] == 0

    # Both sources were consulted exactly once.
    mock_conversation.get_metrics.assert_called_once()
    mock_conversation.event_stream.get_metrics.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_conversation():
    """Test that the metrics endpoint handles the case where no conversation is found."""
    import json

    mock_request = MagicMock()
    mock_request.state = MagicMock()
    # A MagicMock creates any attribute on first access, so simply not
    # assigning `conversation` is not enough: the endpoint would see an
    # auto-created mock conversation and only fail later, in the generic
    # error path. Deleting the attribute makes access raise AttributeError,
    # which is what actually exercises the "no conversation" branch.
    del mock_request.state.conversation

    # Call the endpoint function directly
    response = await get_conversation_metrics(mock_request)

    # Check the response
    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    data = json.loads(response.body.decode('utf-8'))

    # Pin the specific message so the generic "Error getting conversation
    # metrics" response cannot satisfy this test by accident.
    assert 'error' in data
    assert 'No conversation found in request state' in data['error']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_metrics_exception():
    """An exception raised while fetching metrics yields a 500 error."""
    import json

    # Conversation whose get_metrics() raises.
    mock_conversation = MagicMock()
    mock_conversation.get_metrics.side_effect = Exception('Test exception')
    mock_conversation.event_stream = MagicMock()

    mock_request = MagicMock()
    mock_request.state.conversation = mock_conversation

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    data = json.loads(response.body.decode('utf-8'))
    assert 'error' in data
    assert 'Error getting conversation metrics' in data['error']

    # The first call raised, so the event stream fallback is never reached.
    mock_conversation.get_metrics.assert_called_once()
    mock_conversation.event_stream.get_metrics.assert_not_called()
|
||||
Reference in New Issue
Block a user