Compare commits

...

2 Commits

Author SHA1 Message Date
openhands
7b54eb3ab8 Add comprehensive tests for conversation metrics API 2025-03-04 21:40:48 +00:00
openhands
c3cc4c71f9 Fix unreachable code in get_metrics method 2025-03-04 21:38:06 +00:00
5 changed files with 525 additions and 0 deletions

View File

@@ -428,3 +428,41 @@ class EventStream:
break
return matching_events
def get_metrics(self):
"""Get the accumulated metrics from all events in the stream.
This method extracts metrics from events that contain them and returns
the aggregated metrics object.
Returns:
Metrics: The metrics object containing accumulated cost and token usage data.
Returns None if no metrics are found.
"""
from openhands.llm.metrics import Metrics
# Look for events with metrics
metrics = None
events_with_metrics = []
try:
# First collect all events with metrics
for event in self.get_events():
if hasattr(event, 'llm_metrics') and event.llm_metrics is not None:
events_with_metrics.append(event)
# Then merge them if any were found
if events_with_metrics:
# Get the first event with metrics to initialize our metrics object
first_event = events_with_metrics[0]
if first_event.llm_metrics is not None:
metrics = Metrics(model_name=first_event.llm_metrics.model_name)
# Merge metrics from all events
for event in events_with_metrics:
if event.llm_metrics is not None:
metrics.merge(event.llm_metrics)
except Exception as e:
logger.error(f'Error retrieving metrics from events: {e}')
return metrics

View File

@@ -7,6 +7,70 @@ from openhands.runtime.base import Runtime
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@app.get('/metrics')
async def get_conversation_metrics(request: Request):
"""Retrieve the conversation metrics.
This endpoint returns the accumulated cost and token usage metrics for the conversation.
Args:
request (Request): The incoming FastAPI request object.
Returns:
JSONResponse: A JSON response containing the metrics data.
"""
try:
if not hasattr(request.state, 'conversation'):
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'error': 'No conversation found in request state'},
)
event_stream = request.state.conversation.event_stream
# Get metrics from the event stream
metrics = (
event_stream.get_metrics() if hasattr(event_stream, 'get_metrics') else None
)
if not metrics:
# Return empty metrics if not available
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
'accumulated_cost': 0.0,
'total_prompt_tokens': 0,
'total_completion_tokens': 0,
'total_tokens': 0,
},
)
# Calculate total tokens
total_prompt_tokens = sum(usage.prompt_tokens for usage in metrics.token_usages)
total_completion_tokens = sum(
usage.completion_tokens for usage in metrics.token_usages
)
total_tokens = total_prompt_tokens + total_completion_tokens
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
'accumulated_cost': metrics.accumulated_cost,
'total_prompt_tokens': total_prompt_tokens,
'total_completion_tokens': total_completion_tokens,
'total_tokens': total_tokens,
},
)
except Exception as e:
logger.error(f'Error getting conversation metrics: {e}')
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
'error': f'Error getting conversation metrics: {e}',
},
)
@app.get('/config')
async def get_remote_runtime_config(request: Request):
"""Retrieve the runtime configuration.

View File

@@ -0,0 +1,120 @@
import pytest
from unittest.mock import MagicMock, patch
from openhands.events.event import Event
from openhands.events.stream import EventStream
from openhands.llm.metrics import Metrics
class TestEventStream:
def test_get_metrics_empty_stream(self):
"""Test that get_metrics returns None for an empty stream."""
sid = "test-stream-id"
file_store = MagicMock()
stream = EventStream(sid=sid, file_store=file_store)
assert stream.get_metrics() is None
def test_get_metrics_no_metrics_in_events(self):
"""Test that get_metrics returns None when no events have metrics."""
sid = "test-stream-id"
file_store = MagicMock()
stream = EventStream(sid=sid, file_store=file_store)
event = MagicMock(spec=Event)
event.llm_metrics = None
with patch.object(stream, 'get_events', return_value=[event]):
assert stream.get_metrics() is None
def test_get_metrics_with_metrics(self):
"""Test that get_metrics correctly aggregates metrics from events."""
sid = "test-stream-id"
file_store = MagicMock()
stream = EventStream(sid=sid, file_store=file_store)
# Create mock events with metrics
event1 = MagicMock(spec=Event)
metrics1 = Metrics(model_name="gpt-4")
metrics1.add_token_usage(
prompt_tokens=10,
completion_tokens=20,
cache_read_tokens=0,
cache_write_tokens=0,
response_id="resp1"
)
event1.llm_metrics = metrics1
event2 = MagicMock(spec=Event)
metrics2 = Metrics(model_name="gpt-4")
metrics2.add_token_usage(
prompt_tokens=15,
completion_tokens=25,
cache_read_tokens=0,
cache_write_tokens=0,
response_id="resp2"
)
event2.llm_metrics = metrics2
with patch.object(stream, 'get_events', return_value=[event1, event2]):
result = stream.get_metrics()
assert result is not None
assert result.model_name == "gpt-4"
# Check token usages are merged correctly
total_prompt_tokens = sum(usage.prompt_tokens for usage in result.token_usages)
total_completion_tokens = sum(usage.completion_tokens for usage in result.token_usages)
assert total_prompt_tokens == 25 # 10 + 15
assert total_completion_tokens == 45 # 20 + 25
assert len(result.token_usages) == 2
def test_get_metrics_with_exception(self):
"""Test that get_metrics handles exceptions gracefully."""
sid = "test-stream-id"
file_store = MagicMock()
stream = EventStream(sid=sid, file_store=file_store)
with patch.object(stream, 'get_events', side_effect=Exception("Test exception")):
assert stream.get_metrics() is None
def test_get_metrics_with_mixed_events(self):
"""Test that get_metrics correctly handles a mix of events with and without metrics."""
sid = "test-stream-id"
file_store = MagicMock()
stream = EventStream(sid=sid, file_store=file_store)
# Create mock events, some with metrics and some without
event1 = MagicMock(spec=Event)
metrics1 = Metrics(model_name="gpt-4")
metrics1.add_token_usage(
prompt_tokens=10,
completion_tokens=20,
cache_read_tokens=0,
cache_write_tokens=0,
response_id="resp1"
)
event1.llm_metrics = metrics1
event2 = MagicMock(spec=Event)
event2.llm_metrics = None
event3 = MagicMock(spec=Event)
metrics3 = Metrics(model_name="gpt-4")
metrics3.add_token_usage(
prompt_tokens=15,
completion_tokens=25,
cache_read_tokens=0,
cache_write_tokens=0,
response_id="resp3"
)
event3.llm_metrics = metrics3
with patch.object(stream, 'get_events', return_value=[event1, event2, event3]):
result = stream.get_metrics()
assert result is not None
assert result.model_name == "gpt-4"
# Check token usages are merged correctly
total_prompt_tokens = sum(usage.prompt_tokens for usage in result.token_usages)
total_completion_tokens = sum(usage.completion_tokens for usage in result.token_usages)
assert total_prompt_tokens == 25 # 10 + 15
assert total_completion_tokens == 45 # 20 + 25
assert len(result.token_usages) == 2

View File

@@ -0,0 +1,139 @@
import pytest
from unittest.mock import MagicMock, patch
from fastapi import Request, status
from fastapi.responses import JSONResponse
from openhands.server.routes.conversation import get_conversation_metrics
from openhands.llm.metrics import Metrics, TokenUsage
@pytest.fixture
def mock_request():
"""Create a mock request with a conversation."""
request = MagicMock(spec=Request)
request.state.conversation = MagicMock()
request.state.conversation.runtime = MagicMock()
request.state.conversation.event_stream = MagicMock()
return request
@pytest.mark.asyncio
async def test_get_conversation_metrics_success(mock_request):
"""Test successful retrieval of conversation metrics."""
# Setup mock metrics
metrics = Metrics()
metrics.token_usages = [
TokenUsage(
prompt_tokens=100,
completion_tokens=50,
model="test-model",
cache_read_tokens=0,
cache_write_tokens=0,
response_id="test-response-1"
),
TokenUsage(
prompt_tokens=200,
completion_tokens=150,
model="test-model",
cache_read_tokens=0,
cache_write_tokens=0,
response_id="test-response-2"
),
]
metrics.accumulated_cost = 0.25
# Configure mock to return metrics
mock_request.state.conversation.event_stream.get_metrics.return_value = metrics
# Call the endpoint
response = await get_conversation_metrics(mock_request)
# Verify response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_200_OK
# Extract content from JSONResponse
content = response.body.decode('utf-8')
import json
content_dict = json.loads(content)
# Verify metrics
assert content_dict['accumulated_cost'] == 0.25
assert content_dict['total_prompt_tokens'] == 300
assert content_dict['total_completion_tokens'] == 200
assert content_dict['total_tokens'] == 500
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_metrics(mock_request):
"""Test handling when no metrics are available."""
# Configure mock to return None for metrics
mock_request.state.conversation.event_stream.get_metrics.return_value = None
# Call the endpoint
response = await get_conversation_metrics(mock_request)
# Verify response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_200_OK
# Extract content from JSONResponse
content = response.body.decode('utf-8')
import json
content_dict = json.loads(content)
# Verify default metrics
assert content_dict['accumulated_cost'] == 0.0
assert content_dict['total_prompt_tokens'] == 0
assert content_dict['total_completion_tokens'] == 0
assert content_dict['total_tokens'] == 0
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_conversation():
"""Test handling when no conversation is found in request state."""
# Create a request without conversation
request = MagicMock(spec=Request)
request.state = MagicMock()
# Remove conversation attribute
delattr(request.state, 'conversation')
# Call the endpoint
response = await get_conversation_metrics(request)
# Verify response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# Extract content from JSONResponse
content = response.body.decode('utf-8')
import json
content_dict = json.loads(content)
# Verify error message
assert 'error' in content_dict
assert 'No conversation found in request state' in content_dict['error']
@pytest.mark.asyncio
async def test_get_conversation_metrics_exception(mock_request):
"""Test handling when an exception occurs."""
# Configure mock to raise an exception
mock_request.state.conversation.event_stream.get_metrics.side_effect = Exception("Test exception")
# Call the endpoint
response = await get_conversation_metrics(mock_request)
# Verify response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# Extract content from JSONResponse
content = response.body.decode('utf-8')
import json
content_dict = json.loads(content)
# Verify error message
assert 'error' in content_dict
assert 'Test exception' in content_dict['error']

View File

@@ -0,0 +1,164 @@
import pytest
from unittest.mock import MagicMock, patch
from fastapi import Request, status
from fastapi.responses import JSONResponse
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.server.routes.conversation import get_conversation_metrics
@pytest.fixture
def mock_metrics():
metrics = Metrics()
metrics.accumulated_cost = 0.25
metrics.token_usages = [
TokenUsage(
model="gpt-4",
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
response_id="resp1"
),
TokenUsage(
model="gpt-4",
prompt_tokens=200,
completion_tokens=75,
cache_read_tokens=0,
cache_write_tokens=0,
response_id="resp2"
),
]
return metrics
@pytest.mark.asyncio
async def test_get_conversation_metrics_success(mock_metrics):
"""Test that the metrics endpoint returns the correct metrics data."""
# Create a mock request with a conversation that has metrics
mock_event_stream = MagicMock()
mock_event_stream.get_metrics.return_value = mock_metrics
mock_conversation = MagicMock()
mock_conversation.event_stream = mock_event_stream
mock_request = MagicMock()
mock_request.state.conversation = mock_conversation
# Call the endpoint function directly
response = await get_conversation_metrics(mock_request)
# Check the response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_200_OK
# Extract the content from the response
content = response.body.decode('utf-8')
import json
data = json.loads(content)
# Verify the metrics data
assert data["accumulated_cost"] == 0.25
assert data["total_prompt_tokens"] == 300
assert data["total_completion_tokens"] == 125
assert data["total_tokens"] == 425
# Verify the get_metrics method was called
mock_conversation.event_stream.get_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_metrics():
"""Test that the metrics endpoint handles the case where no metrics are available."""
# Create a mock request with a conversation that has no metrics
mock_event_stream = MagicMock()
mock_event_stream.get_metrics.return_value = None
mock_conversation = MagicMock()
mock_conversation.event_stream = mock_event_stream
mock_request = MagicMock()
mock_request.state.conversation = mock_conversation
# Call the endpoint function directly
response = await get_conversation_metrics(mock_request)
# Check the response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_200_OK
# Extract the content from the response
content = response.body.decode('utf-8')
import json
data = json.loads(content)
# Verify the metrics data
assert data["accumulated_cost"] == 0.0
assert data["total_prompt_tokens"] == 0
assert data["total_completion_tokens"] == 0
assert data["total_tokens"] == 0
# Verify the get_metrics method was called
mock_conversation.event_stream.get_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_conversation():
"""Test that the metrics endpoint handles the case where no conversation is found."""
# Create a mock request with no conversation attribute
mock_request = MagicMock()
mock_request.state = MagicMock()
# Intentionally not setting request.state.conversation
# Call the endpoint function directly
response = await get_conversation_metrics(mock_request)
# Check the response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# Extract the content from the response
content = response.body.decode('utf-8')
import json
data = json.loads(content)
# Verify the error message
assert "error" in data
assert "No conversation found" in data["error"] or "conversation" in data["error"].lower()
@pytest.mark.asyncio
async def test_get_conversation_metrics_exception():
"""Test that the metrics endpoint handles exceptions gracefully."""
# Create a mock request with a conversation that raises an exception
mock_event_stream = MagicMock()
mock_event_stream.get_metrics.side_effect = Exception("Test exception")
mock_conversation = MagicMock()
mock_conversation.event_stream = mock_event_stream
mock_request = MagicMock()
mock_request.state.conversation = mock_conversation
# Call the endpoint function directly
response = await get_conversation_metrics(mock_request)
# Check the response
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# Extract the content from the response
content = response.body.decode('utf-8')
import json
data = json.loads(content)
# Verify the error message
assert "error" in data
assert "Test exception" in data["error"]
# Verify the get_metrics method was called
mock_conversation.event_stream.get_metrics.assert_called_once()