Compare commits

...

3 Commits

Author SHA1 Message Date
openhands
9a8d6f1f1c Fix #7105: Expose metrics directly on conversation rather than inside events 2025-03-04 21:51:22 +00:00
openhands
7b54eb3ab8 Add comprehensive tests for conversation metrics API 2025-03-04 21:40:48 +00:00
openhands
c3cc4c71f9 Fix unreachable code in get_metrics method 2025-03-04 21:38:06 +00:00
6 changed files with 668 additions and 0 deletions

View File

@@ -428,3 +428,41 @@ class EventStream:
break
return matching_events
def get_metrics(self):
    """Aggregate the LLM metrics attached to the events of this stream.

    Every event yielded by ``self.get_events()`` that has a non-None
    ``llm_metrics`` attribute contributes. A fresh ``Metrics`` object is
    created using the model name of the first contributing event, and the
    metrics of every contributing event (the first one included) are merged
    into it.

    Any exception raised while reading or merging is logged rather than
    propagated; whatever was accumulated up to that point is returned.

    Returns:
        Metrics: The merged metrics, or None when no event carries metrics
        (or when a failure occurred before a metrics object was created).
    """
    from openhands.llm.metrics import Metrics

    metrics = None
    try:
        # Gather every metrics-bearing event before merging anything, so a
        # failure while reading events leaves the result as None.
        carriers = [
            ev
            for ev in self.get_events()
            if getattr(ev, 'llm_metrics', None) is not None
        ]
        if carriers:
            # The first contributing event decides the model name.
            metrics = Metrics(model_name=carriers[0].llm_metrics.model_name)
            for ev in carriers:
                metrics.merge(ev.llm_metrics)
    except Exception as e:
        logger.error(f'Error retrieving metrics from events: {e}')
    return metrics

View File

@@ -7,6 +7,74 @@ from openhands.runtime.base import Runtime
# Router for the conversation endpoints; every route declared on it is
# served under the /api/conversations/{conversation_id} prefix.
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@app.get('/metrics')
async def get_conversation_metrics(request: Request):
    """Return accumulated cost and token totals for the conversation.

    The metrics are taken from ``conversation.get_metrics()``. When that
    yields nothing and the conversation's event stream exposes a
    ``get_metrics`` method, the event stream is consulted instead. If neither
    source provides metrics, a payload of zeros is returned.

    Args:
        request (Request): The incoming FastAPI request; the conversation is
            read from ``request.state.conversation``.

    Returns:
        JSONResponse: 200 with ``accumulated_cost``, ``total_prompt_tokens``,
        ``total_completion_tokens`` and ``total_tokens``; 500 with an
        ``error`` message when the request carries no conversation or when
        any exception is raised while building the response.
    """
    try:
        if not hasattr(request.state, 'conversation'):
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={'error': 'No conversation found in request state'},
            )

        conversation = request.state.conversation

        # Preferred source: the conversation itself.
        metrics = conversation.get_metrics()

        # Backward-compatible fallback: rebuild from the event stream.
        if not metrics and hasattr(conversation.event_stream, 'get_metrics'):
            metrics = conversation.event_stream.get_metrics()

        if metrics:
            prompt_total = sum(u.prompt_tokens for u in metrics.token_usages)
            completion_total = sum(
                u.completion_tokens for u in metrics.token_usages
            )
            payload = {
                'accumulated_cost': metrics.accumulated_cost,
                'total_prompt_tokens': prompt_total,
                'total_completion_tokens': completion_total,
                'total_tokens': prompt_total + completion_total,
            }
        else:
            # Nothing recorded anywhere: report zeros, not an error.
            payload = {
                'accumulated_cost': 0.0,
                'total_prompt_tokens': 0,
                'total_completion_tokens': 0,
                'total_tokens': 0,
            }

        return JSONResponse(status_code=status.HTTP_200_OK, content=payload)
    except Exception as e:
        logger.error(f'Error getting conversation metrics: {e}')
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                'error': f'Error getting conversation metrics: {e}',
            },
        )
@app.get('/config')
async def get_remote_runtime_config(request: Request):
"""Retrieve the runtime configuration.

View File

@@ -46,3 +46,24 @@ class Conversation:
if self.event_stream:
self.event_stream.close()
asyncio.create_task(call_sync_from_async(self.runtime.close))
def get_metrics(self):
    """Return the metrics object held by the runtime's state.

    Reads ``self.runtime.state.metrics``. If the runtime has no ``state``
    attribute, or the state is falsy, there is nothing to report.

    Returns:
        The ``metrics`` attribute of the runtime state, or None when the
        runtime has no usable state or when reading it raises (the error is
        logged in that case).
    """
    try:
        state = getattr(self.runtime, 'state', None)
        return state.metrics if state else None
    except Exception as e:
        # Imported here, as in the original, rather than at module level.
        from openhands.core.logger import openhands_logger as logger

        logger.error(f'Error retrieving metrics from runtime state: {e}')
        return None

View File

@@ -0,0 +1,129 @@
from unittest.mock import MagicMock, patch
from openhands.events.event import Event
from openhands.events.stream import EventStream
from openhands.llm.metrics import Metrics
class TestEventStream:
    """Tests for ``EventStream.get_metrics``."""

    @staticmethod
    def _new_stream():
        """Build an EventStream backed by a mocked file store."""
        return EventStream(sid='test-stream-id', file_store=MagicMock())

    @staticmethod
    def _usage_metrics(prompt_tokens, completion_tokens, response_id):
        """Build gpt-4 metrics holding a single token-usage record."""
        metrics = Metrics(model_name='gpt-4')
        metrics.add_token_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id=response_id,
        )
        return metrics

    @staticmethod
    def _make_event(metrics=None):
        """Build a mocked Event whose ``llm_metrics`` is *metrics*."""
        event = MagicMock(spec=Event)
        event.llm_metrics = metrics
        return event

    def test_get_metrics_empty_stream(self):
        """Test that get_metrics returns None for an empty stream."""
        assert self._new_stream().get_metrics() is None

    def test_get_metrics_no_metrics_in_events(self):
        """Test that get_metrics returns None when no events have metrics."""
        stream = self._new_stream()
        events = [self._make_event()]

        with patch.object(stream, 'get_events', return_value=events):
            assert stream.get_metrics() is None

    def test_get_metrics_with_metrics(self):
        """Test that get_metrics correctly aggregates metrics from events."""
        stream = self._new_stream()
        events = [
            self._make_event(self._usage_metrics(10, 20, 'resp1')),
            self._make_event(self._usage_metrics(15, 25, 'resp2')),
        ]

        with patch.object(stream, 'get_events', return_value=events):
            result = stream.get_metrics()

            assert result is not None
            assert result.model_name == 'gpt-4'

            # Token usages from both events must be present and summed.
            prompt_total = sum(u.prompt_tokens for u in result.token_usages)
            completion_total = sum(
                u.completion_tokens for u in result.token_usages
            )
            assert prompt_total == 25  # 10 + 15
            assert completion_total == 45  # 20 + 25
            assert len(result.token_usages) == 2

    def test_get_metrics_with_exception(self):
        """Test that get_metrics handles exceptions gracefully."""
        stream = self._new_stream()

        with patch.object(
            stream, 'get_events', side_effect=Exception('Test exception')
        ):
            assert stream.get_metrics() is None

    def test_get_metrics_with_mixed_events(self):
        """Test that get_metrics correctly handles a mix of events with and without metrics."""
        stream = self._new_stream()
        events = [
            self._make_event(self._usage_metrics(10, 20, 'resp1')),
            self._make_event(),  # carries no metrics
            self._make_event(self._usage_metrics(15, 25, 'resp3')),
        ]

        with patch.object(stream, 'get_events', return_value=events):
            result = stream.get_metrics()

            assert result is not None
            assert result.model_name == 'gpt-4'

            # Only the two metrics-bearing events contribute.
            prompt_total = sum(u.prompt_tokens for u in result.token_usages)
            completion_total = sum(
                u.completion_tokens for u in result.token_usages
            )
            assert prompt_total == 25  # 10 + 15
            assert completion_total == 45  # 20 + 25
            assert len(result.token_usages) == 2

View File

@@ -0,0 +1,203 @@
from unittest.mock import MagicMock
import pytest
from fastapi import Request, status
from fastapi.responses import JSONResponse
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.server.routes.conversation import get_conversation_metrics
@pytest.fixture
def mock_request():
    """Provide a mocked Request whose state carries a mocked conversation."""
    conversation = MagicMock()
    conversation.runtime = MagicMock()
    conversation.event_stream = MagicMock()
    # Explicit mock so tests can configure the return value or side effect.
    conversation.get_metrics = MagicMock()

    request = MagicMock(spec=Request)
    request.state.conversation = conversation
    return request
@pytest.mark.asyncio
async def test_get_conversation_metrics_success(mock_request):
    """Test successful retrieval of conversation metrics from runtime state."""
    import json

    conversation = mock_request.state.conversation

    # Two usage records: 100/50 and 200/150 prompt/completion tokens.
    usage_specs = [
        (100, 50, 'test-response-1'),
        (200, 150, 'test-response-2'),
    ]
    metrics = Metrics()
    metrics.token_usages = [
        TokenUsage(
            model='test-model',
            prompt_tokens=prompt,
            completion_tokens=completion,
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id=response_id,
        )
        for prompt, completion, response_id in usage_specs
    ]
    metrics.accumulated_cost = 0.25
    conversation.get_metrics.return_value = metrics

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    payload = json.loads(response.body.decode('utf-8'))
    expected = {
        'accumulated_cost': 0.25,
        'total_prompt_tokens': 300,
        'total_completion_tokens': 200,
        'total_tokens': 500,
    }
    for key, value in expected.items():
        assert payload[key] == value

    # The conversation supplied the metrics, so no fallback was needed.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_not_called()
@pytest.mark.asyncio
async def test_get_conversation_metrics_fallback_to_event_stream(mock_request):
    """Test fallback to event_stream metrics when runtime state metrics are not available."""
    import json

    conversation = mock_request.state.conversation

    metrics = Metrics()
    metrics.token_usages = [
        TokenUsage(
            model='test-model',
            prompt_tokens=100,
            completion_tokens=50,
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id='test-response-1',
        ),
    ]
    metrics.accumulated_cost = 0.15

    # The conversation yields nothing; only the event stream has metrics.
    conversation.get_metrics.return_value = None
    conversation.event_stream.get_metrics.return_value = metrics

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    payload = json.loads(response.body.decode('utf-8'))
    expected = {
        'accumulated_cost': 0.15,
        'total_prompt_tokens': 100,
        'total_completion_tokens': 50,
        'total_tokens': 150,
    }
    for key, value in expected.items():
        assert payload[key] == value

    # Both sources were consulted, each exactly once.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_metrics(mock_request):
    """Test handling when no metrics are available from either source."""
    import json

    conversation = mock_request.state.conversation
    conversation.get_metrics.return_value = None
    conversation.event_stream.get_metrics.return_value = None

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    # With no metrics anywhere the endpoint reports zeros.
    payload = json.loads(response.body.decode('utf-8'))
    expected = {
        'accumulated_cost': 0.0,
        'total_prompt_tokens': 0,
        'total_completion_tokens': 0,
        'total_tokens': 0,
    }
    for key, value in expected.items():
        assert payload[key] == value

    # Both sources were consulted, each exactly once.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_conversation():
    """Test handling when no conversation is found in request state."""
    import json

    request = MagicMock(spec=Request)
    request.state = MagicMock()
    # Deleting the attribute makes hasattr(request.state, 'conversation')
    # report False instead of the mock fabricating one on access.
    del request.state.conversation

    response = await get_conversation_metrics(request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    payload = json.loads(response.body.decode('utf-8'))
    assert 'error' in payload
    assert 'No conversation found in request state' in payload['error']
@pytest.mark.asyncio
async def test_get_conversation_metrics_exception(mock_request):
    """Test handling when an exception occurs."""
    import json

    # Make the primary metrics lookup blow up.
    failure = Exception('Test exception')
    mock_request.state.conversation.get_metrics.side_effect = failure

    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    payload = json.loads(response.body.decode('utf-8'))
    assert 'error' in payload
    assert 'Error getting conversation metrics' in payload['error']

View File

@@ -0,0 +1,209 @@
from unittest.mock import MagicMock
import pytest
from fastapi import status
from fastapi.responses import JSONResponse
from openhands.llm.metrics import Metrics, TokenUsage
from openhands.server.routes.conversation import get_conversation_metrics
@pytest.fixture
def mock_metrics():
    """Provide Metrics with cost 0.25 and two gpt-4 token-usage records."""
    # (prompt_tokens, completion_tokens, response_id) per usage record.
    usage_specs = [
        (100, 50, 'resp1'),
        (200, 75, 'resp2'),
    ]

    metrics = Metrics()
    metrics.accumulated_cost = 0.25
    metrics.token_usages = [
        TokenUsage(
            model='gpt-4',
            prompt_tokens=prompt,
            completion_tokens=completion,
            cache_read_tokens=0,
            cache_write_tokens=0,
            response_id=response_id,
        )
        for prompt, completion, response_id in usage_specs
    ]
    return metrics
@pytest.mark.asyncio
async def test_get_conversation_metrics_success(mock_metrics):
    """Test that the metrics endpoint returns the correct metrics data from runtime state."""
    import json

    # Conversation whose own get_metrics() already yields the metrics.
    conversation = MagicMock()
    conversation.get_metrics = MagicMock(return_value=mock_metrics)
    conversation.event_stream = MagicMock()
    request = MagicMock()
    request.state.conversation = conversation

    response = await get_conversation_metrics(request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    data = json.loads(response.body.decode('utf-8'))
    expected = {
        'accumulated_cost': 0.25,
        'total_prompt_tokens': 300,
        'total_completion_tokens': 125,
        'total_tokens': 425,
    }
    for key, value in expected.items():
        assert data[key] == value

    # The conversation supplied the metrics, so no fallback was needed.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_not_called()
@pytest.mark.asyncio
async def test_get_conversation_metrics_fallback_to_event_stream(mock_metrics):
    """Test fallback to event_stream metrics when runtime state metrics are not available."""
    import json

    # Conversation yields nothing itself; only its event stream has metrics.
    conversation = MagicMock()
    conversation.get_metrics = MagicMock(return_value=None)
    conversation.event_stream = MagicMock()
    conversation.event_stream.get_metrics = MagicMock(return_value=mock_metrics)
    request = MagicMock()
    request.state.conversation = conversation

    response = await get_conversation_metrics(request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    data = json.loads(response.body.decode('utf-8'))
    expected = {
        'accumulated_cost': 0.25,
        'total_prompt_tokens': 300,
        'total_completion_tokens': 125,
        'total_tokens': 425,
    }
    for key, value in expected.items():
        assert data[key] == value

    # Both sources were consulted, each exactly once.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_metrics():
    """Test that the metrics endpoint handles the case where no metrics are available from either source."""
    import json

    # Neither the conversation nor its event stream has any metrics.
    conversation = MagicMock()
    conversation.get_metrics = MagicMock(return_value=None)
    conversation.event_stream = MagicMock()
    conversation.event_stream.get_metrics = MagicMock(return_value=None)
    request = MagicMock()
    request.state.conversation = conversation

    response = await get_conversation_metrics(request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_200_OK

    data = json.loads(response.body.decode('utf-8'))
    expected = {
        'accumulated_cost': 0.0,
        'total_prompt_tokens': 0,
        'total_completion_tokens': 0,
        'total_tokens': 0,
    }
    for key, value in expected.items():
        assert data[key] == value

    # Both sources were consulted, each exactly once.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_metrics_no_conversation():
    """Test that the metrics endpoint handles the case where no conversation is found.

    The request state must genuinely lack a ``conversation`` attribute so the
    endpoint's ``hasattr`` guard fails and its dedicated error response is
    returned.
    """
    import json

    mock_request = MagicMock()
    mock_request.state = MagicMock()
    # A MagicMock fabricates any attribute on first access, so simply not
    # assigning `conversation` would leave hasattr(...) True and the endpoint
    # would carry on with a fabricated conversation. Deleting the attribute
    # makes access raise AttributeError, so hasattr(...) is False.
    del mock_request.state.conversation

    # Call the endpoint function directly
    response = await get_conversation_metrics(mock_request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    data = json.loads(response.body.decode('utf-8'))

    # Pin the specific message: a loose "contains 'conversation'" check is
    # also satisfied by the generic "Error getting conversation metrics".
    assert 'error' in data
    assert 'No conversation found in request state' in data['error']
@pytest.mark.asyncio
async def test_get_conversation_metrics_exception():
    """Test that the metrics endpoint handles exceptions gracefully."""
    import json

    # Conversation whose primary metrics lookup raises.
    conversation = MagicMock()
    conversation.get_metrics.side_effect = Exception('Test exception')
    conversation.event_stream = MagicMock()
    request = MagicMock()
    request.state.conversation = conversation

    response = await get_conversation_metrics(request)

    assert isinstance(response, JSONResponse)
    assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

    data = json.loads(response.body.decode('utf-8'))
    assert 'error' in data
    assert 'Error getting conversation metrics' in data['error']

    # The failing call happened once; the fallback was never reached because
    # the exception aborted the handler first.
    conversation.get_metrics.assert_called_once()
    conversation.event_stream.get_metrics.assert_not_called()