fix(backend): persist conversation statistics to the database for V1 conversations (#11837)

This commit is contained in:
Hiep Le
2025-12-02 21:22:02 +07:00
committed by GitHub
parent 1f9350320f
commit f76ac242f0
4 changed files with 763 additions and 0 deletions
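At a glance: the webhook now forwards 'stats' state updates to the conversation info service, which copies the agent metrics onto the stored V1 metadata row. A minimal sketch of that flow, assuming an already-constructed SQLAppConversationInfoService bound to an AsyncSession (the 'service' argument and helper name below are hypothetical; the payload shape mirrors the fixtures in the new test file):

from uuid import UUID

from openhands.sdk.event import ConversationStateUpdateEvent


async def record_example_stats(service, conversation_id: UUID) -> None:
    # 'service' is assumed to be a SQLAppConversationInfoService; the dict below
    # follows the usage_to_metrics shape exercised by the tests in this commit.
    stats_event = ConversationStateUpdateEvent(
        key='stats',
        value={
            'usage_to_metrics': {
                'agent': {
                    'accumulated_cost': 0.05,
                    'accumulated_token_usage': {
                        'prompt_tokens': 1000,
                        'completion_tokens': 100,
                    },
                }
            }
        },
    )
    # Copies accumulated_cost and the token counters onto the matching
    # StoredConversationMetadata row (V1 conversations only).
    await service.process_stats_event(stats_event, conversation_id)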

View File

@@ -9,6 +9,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationSortOrder,
)
from openhands.app_server.services.injector import Injector
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.sdk.utils.models import DiscriminatedUnionMixin
@@ -92,6 +93,19 @@ class AppConversationInfoService(ABC):
Return the stored info
"""
@abstractmethod
async def process_stats_event(
self,
event: ConversationStateUpdateEvent,
conversation_id: UUID,
) -> None:
"""Process a stats event and update conversation statistics.
Args:
event: The ConversationStateUpdateEvent with key='stats'
conversation_id: The ID of the conversation to update
"""
class AppConversationInfoServiceInjector(
DiscriminatedUnionMixin, Injector[AppConversationInfoService], ABC

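Concrete services must now implement this hook. A minimal in-memory sketch of the contract (hypothetical; a real subclass of AppConversationInfoService would also have to implement the service's other abstract methods, which are not shown in this hunk):

from uuid import UUID

from openhands.sdk.event import ConversationStateUpdateEvent


class InMemoryStatsRecorder:
    """Illustrative stand-in that records raw stats payloads per conversation."""

    def __init__(self) -> None:
        self._stats_by_conversation: dict[UUID, object] = {}

    async def process_stats_event(
        self,
        event: ConversationStateUpdateEvent,
        conversation_id: UUID,
    ) -> None:
        # Only 'stats' updates are expected here; other keys are ignored.
        if event.key != 'stats':
            return
        self._stats_by_conversation[conversation_id] = event.value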
View File

@@ -45,6 +45,8 @@ from openhands.app_server.utils.sql_utils import (
create_json_type_decorator,
)
from openhands.integrations.provider import ProviderType
from openhands.sdk.conversation.conversation_stats import ConversationStats
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@@ -354,6 +356,130 @@ class SQLAppConversationInfoService(AppConversationInfoService):
await self.db_session.commit()
return info
async def update_conversation_statistics(
self, conversation_id: UUID, stats: ConversationStats
) -> None:
"""Update conversation statistics from stats event data.
Args:
conversation_id: The ID of the conversation to update
stats: ConversationStats object containing usage_to_metrics data from stats event
"""
# Extract agent metrics from usage_to_metrics
usage_to_metrics = stats.usage_to_metrics
agent_metrics = usage_to_metrics.get('agent')
if not agent_metrics:
logger.debug(
'No agent metrics found in stats for conversation %s', conversation_id
)
return
# Query existing record using secure select (filters for V1 and user if available)
query = await self._secure_select()
query = query.where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
stored = result.scalar_one_or_none()
if not stored:
logger.debug(
'Conversation %s not found or not accessible, skipping statistics update',
conversation_id,
)
return
# Extract accumulated_cost and max_budget_per_task from Metrics object
accumulated_cost = agent_metrics.accumulated_cost
max_budget_per_task = agent_metrics.max_budget_per_task
# Extract accumulated_token_usage from Metrics object
accumulated_token_usage = agent_metrics.accumulated_token_usage
if accumulated_token_usage:
prompt_tokens = accumulated_token_usage.prompt_tokens
completion_tokens = accumulated_token_usage.completion_tokens
cache_read_tokens = accumulated_token_usage.cache_read_tokens
cache_write_tokens = accumulated_token_usage.cache_write_tokens
reasoning_tokens = accumulated_token_usage.reasoning_tokens
context_window = accumulated_token_usage.context_window
per_turn_token = accumulated_token_usage.per_turn_token
else:
prompt_tokens = None
completion_tokens = None
cache_read_tokens = None
cache_write_tokens = None
reasoning_tokens = None
context_window = None
per_turn_token = None
# Update fields only if values are provided (not None)
if accumulated_cost is not None:
stored.accumulated_cost = accumulated_cost
if max_budget_per_task is not None:
stored.max_budget_per_task = max_budget_per_task
if prompt_tokens is not None:
stored.prompt_tokens = prompt_tokens
if completion_tokens is not None:
stored.completion_tokens = completion_tokens
if cache_read_tokens is not None:
stored.cache_read_tokens = cache_read_tokens
if cache_write_tokens is not None:
stored.cache_write_tokens = cache_write_tokens
if reasoning_tokens is not None:
stored.reasoning_tokens = reasoning_tokens
if context_window is not None:
stored.context_window = context_window
if per_turn_token is not None:
stored.per_turn_token = per_turn_token
# Update last_updated_at timestamp
stored.last_updated_at = utc_now()
await self.db_session.commit()
async def process_stats_event(
self,
event: ConversationStateUpdateEvent,
conversation_id: UUID,
) -> None:
"""Process a stats event and update conversation statistics.
Args:
event: The ConversationStateUpdateEvent with key='stats'
conversation_id: The ID of the conversation to update
"""
try:
# Parse event value into ConversationStats model for type safety
# event.value can be a dict (from JSON deserialization) or a ConversationStats object
event_value = event.value
conversation_stats: ConversationStats | None = None
if isinstance(event_value, ConversationStats):
# Already a ConversationStats object
conversation_stats = event_value
elif isinstance(event_value, dict):
# Parse dict into ConversationStats model
# This validates the structure and ensures type safety
conversation_stats = ConversationStats.model_validate(event_value)
elif hasattr(event_value, 'usage_to_metrics'):
# Handle objects with usage_to_metrics attribute (e.g., from tests)
# Convert to dict first, then validate
stats_dict = {'usage_to_metrics': event_value.usage_to_metrics}
conversation_stats = ConversationStats.model_validate(stats_dict)
if conversation_stats and conversation_stats.usage_to_metrics:
# Pass ConversationStats object directly for type safety
await self.update_conversation_statistics(
conversation_id, conversation_stats
)
except Exception:
logger.exception(
'Error updating conversation statistics for conversation %s',
conversation_id,
stack_info=True,
)
async def _secure_select(self):
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'

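For reference, the tests added in this commit drive update_conversation_statistics by building a ConversationStats directly from Metrics and TokenUsage. A condensed sketch of that call, with keyword arguments mirroring the new test fixtures (the 'service' argument is assumed to be a SQLAppConversationInfoService):

from uuid import uuid4

from openhands.sdk.conversation.conversation_stats import ConversationStats
from openhands.sdk.llm.utils.metrics import Metrics, TokenUsage


async def example_update(service) -> None:
    agent_metrics = Metrics(
        model_name='test-model',
        accumulated_cost=0.05,
        accumulated_token_usage=TokenUsage(
            model='test-model',
            prompt_tokens=200,
            completion_tokens=20,
        ),
    )
    stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
    # Only non-None fields are written; fields that are None leave the stored
    # values untouched, and the row's last_updated_at is refreshed.
    await service.update_conversation_statistics(uuid4(), stats)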
View File

@@ -43,6 +43,7 @@ from openhands.app_server.user.specifiy_user_context import (
from openhands.app_server.user.user_context import UserContext
from openhands.integrations.provider import ProviderType
from openhands.sdk import Event
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.server.user_auth.default_user_auth import DefaultUserAuth
from openhands.server.user_auth.user_auth import (
get_for_user as get_user_auth_for_user,
@@ -144,6 +145,13 @@ async def on_event(
*[event_service.save_event(conversation_id, event) for event in events]
)
# Process stats events for V1 conversations
for event in events:
if isinstance(event, ConversationStateUpdateEvent) and event.key == 'stats':
await app_conversation_info_service.process_stats_event(
event, conversation_id
)
asyncio.create_task(
_run_callbacks_in_bg_and_close(
conversation_id, app_conversation_info.created_by_user_id, events

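The router only forwards ConversationStateUpdateEvent instances whose key is 'stats'; all other events are saved and otherwise left alone. A small illustrative sketch of that predicate (the helper name is hypothetical; the router inlines the check in the loop above):

from openhands.sdk import Event
from openhands.sdk.event import ConversationStateUpdateEvent


def is_stats_event(event: Event) -> bool:
    # Mirrors the inline check in on_event.
    return isinstance(event, ConversationStateUpdateEvent) and event.key == 'stats'


# Example: out of a mixed batch, only the 'stats' update reaches
# process_stats_event; the 'execution_status' update is saved but skipped here.
events = [
    ConversationStateUpdateEvent(key='stats', value={'usage_to_metrics': {}}),
    ConversationStateUpdateEvent(key='execution_status', value='running'),
]
stats_events = [event for event in events if is_stats_event(event)]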
View File

@@ -0,0 +1,615 @@
"""Tests for stats event processing in webhook_router.
This module tests the stats event processing functionality introduced for
updating conversation statistics from ConversationStateUpdateEvent events.
"""
from datetime import datetime, timezone
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
StoredConversationMetadata,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.sdk.conversation.conversation_stats import ConversationStats
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.sdk.llm.utils.metrics import Metrics, TokenUsage
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
def service(async_session) -> SQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance for testing."""
return SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
async def v1_conversation_metadata(async_session, service):
"""Create a V1 conversation metadata record for testing."""
conversation_id = uuid4()
stored = StoredConversationMetadata(
conversation_id=str(conversation_id),
user_id='test_user_123',
sandbox_id='sandbox_123',
conversation_version='V1',
title='Test Conversation',
accumulated_cost=0.0,
prompt_tokens=0,
completion_tokens=0,
cache_read_tokens=0,
cache_write_tokens=0,
reasoning_tokens=0,
context_window=0,
per_turn_token=0,
created_at=datetime.now(timezone.utc),
last_updated_at=datetime.now(timezone.utc),
)
async_session.add(stored)
await async_session.commit()
return conversation_id, stored
@pytest.fixture
def stats_event_with_dict_value():
"""Create a ConversationStateUpdateEvent with dict value."""
event_value = {
'usage_to_metrics': {
'agent': {
'accumulated_cost': 0.03411525,
'max_budget_per_task': None,
'accumulated_token_usage': {
'prompt_tokens': 8770,
'completion_tokens': 82,
'cache_read_tokens': 0,
'cache_write_tokens': 8767,
'reasoning_tokens': 0,
'context_window': 0,
'per_turn_token': 8852,
},
},
'condenser': {
'accumulated_cost': 0.0,
'accumulated_token_usage': {
'prompt_tokens': 0,
'completion_tokens': 0,
},
},
}
}
return ConversationStateUpdateEvent(key='stats', value=event_value)
@pytest.fixture
def stats_event_with_object_value():
"""Create a ConversationStateUpdateEvent with object value."""
event_value = MagicMock()
event_value.usage_to_metrics = {
'agent': {
'accumulated_cost': 0.05,
'accumulated_token_usage': {
'prompt_tokens': 1000,
'completion_tokens': 100,
},
}
}
return ConversationStateUpdateEvent(key='stats', value=event_value)
@pytest.fixture
def stats_event_no_usage_to_metrics():
"""Create a ConversationStateUpdateEvent without usage_to_metrics."""
event_value = {'some_other_key': 'value'}
return ConversationStateUpdateEvent(key='stats', value=event_value)
# ---------------------------------------------------------------------------
# Tests for update_conversation_statistics
# ---------------------------------------------------------------------------
class TestUpdateConversationStatistics:
"""Test the update_conversation_statistics method."""
@pytest.mark.asyncio
async def test_update_statistics_success(
self, service, async_session, v1_conversation_metadata
):
"""Test successfully updating conversation statistics."""
conversation_id, stored = v1_conversation_metadata
agent_metrics = Metrics(
model_name='test-model',
accumulated_cost=0.03411525,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
model='test-model',
prompt_tokens=8770,
completion_tokens=82,
cache_read_tokens=0,
cache_write_tokens=8767,
reasoning_tokens=0,
context_window=0,
per_turn_token=8852,
),
)
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
await service.update_conversation_statistics(conversation_id, stats)
# Verify the update
await async_session.refresh(stored)
assert stored.accumulated_cost == 0.03411525
assert stored.max_budget_per_task == 10.0
assert stored.prompt_tokens == 8770
assert stored.completion_tokens == 82
assert stored.cache_read_tokens == 0
assert stored.cache_write_tokens == 8767
assert stored.reasoning_tokens == 0
assert stored.context_window == 0
assert stored.per_turn_token == 8852
assert stored.last_updated_at is not None
@pytest.mark.asyncio
async def test_update_statistics_partial_update(
self, service, async_session, v1_conversation_metadata
):
"""Test updating only some statistics fields."""
conversation_id, stored = v1_conversation_metadata
# Set initial values
stored.accumulated_cost = 0.01
stored.prompt_tokens = 100
await async_session.commit()
agent_metrics = Metrics(
model_name='test-model',
accumulated_cost=0.05,
accumulated_token_usage=TokenUsage(
model='test-model',
prompt_tokens=200,
completion_tokens=0, # Default value
),
)
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
await service.update_conversation_statistics(conversation_id, stats)
# Verify updated fields
await async_session.refresh(stored)
assert stored.accumulated_cost == 0.05
assert stored.prompt_tokens == 200
# completion_tokens is 0 in the stats (not None), so it is written and still equals 0
assert stored.completion_tokens == 0
@pytest.mark.asyncio
async def test_update_statistics_no_agent_metrics(
self, service, v1_conversation_metadata
):
"""Test that update is skipped when no agent metrics are present."""
conversation_id, stored = v1_conversation_metadata
original_cost = stored.accumulated_cost
condenser_metrics = Metrics(
model_name='test-model',
accumulated_cost=0.1,
)
stats = ConversationStats(usage_to_metrics={'condenser': condenser_metrics})
await service.update_conversation_statistics(conversation_id, stats)
# Verify no update occurred
assert stored.accumulated_cost == original_cost
@pytest.mark.asyncio
async def test_update_statistics_conversation_not_found(self, service):
"""Test that update is skipped when conversation doesn't exist."""
nonexistent_id = uuid4()
agent_metrics = Metrics(
model_name='test-model',
accumulated_cost=0.1,
)
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
# Should not raise an exception
await service.update_conversation_statistics(nonexistent_id, stats)
@pytest.mark.asyncio
async def test_update_statistics_v0_conversation_skipped(
self, service, async_session
):
"""Test that V0 conversations are skipped."""
conversation_id = uuid4()
stored = StoredConversationMetadata(
conversation_id=str(conversation_id),
user_id='test_user_123',
sandbox_id='sandbox_123',
conversation_version='V0', # V0 conversation
title='V0 Conversation',
accumulated_cost=0.0,
created_at=datetime.now(timezone.utc),
last_updated_at=datetime.now(timezone.utc),
)
async_session.add(stored)
await async_session.commit()
original_cost = stored.accumulated_cost
agent_metrics = Metrics(
model_name='test-model',
accumulated_cost=0.1,
)
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
await service.update_conversation_statistics(conversation_id, stats)
# Verify no update occurred
await async_session.refresh(stored)
assert stored.accumulated_cost == original_cost
@pytest.mark.asyncio
async def test_update_statistics_with_none_values(
self, service, async_session, v1_conversation_metadata
):
"""Test that None values in stats don't overwrite existing values."""
conversation_id, stored = v1_conversation_metadata
# Set initial values
stored.accumulated_cost = 0.01
stored.max_budget_per_task = 5.0
stored.prompt_tokens = 100
await async_session.commit()
agent_metrics = Metrics(
model_name='test-model',
accumulated_cost=0.05,
max_budget_per_task=None, # None value
accumulated_token_usage=TokenUsage(
model='test-model',
prompt_tokens=200,
completion_tokens=0, # Default value (None is not valid for int)
),
)
stats = ConversationStats(usage_to_metrics={'agent': agent_metrics})
await service.update_conversation_statistics(conversation_id, stats)
# Verify updated fields and that None values didn't overwrite
await async_session.refresh(stored)
assert stored.accumulated_cost == 0.05
assert stored.max_budget_per_task == 5.0 # Should remain unchanged
assert stored.prompt_tokens == 200
assert (
stored.completion_tokens == 0
)  # Written as 0 (0 is not None); only None values are skipped
# ---------------------------------------------------------------------------
# Tests for process_stats_event
# ---------------------------------------------------------------------------
class TestProcessStatsEvent:
"""Test the process_stats_event method."""
@pytest.mark.asyncio
async def test_process_stats_event_with_dict_value(
self,
service,
async_session,
stats_event_with_dict_value,
v1_conversation_metadata,
):
"""Test processing stats event with dict value."""
conversation_id, stored = v1_conversation_metadata
await service.process_stats_event(stats_event_with_dict_value, conversation_id)
# Verify the update occurred
await async_session.refresh(stored)
assert stored.accumulated_cost == 0.03411525
assert stored.prompt_tokens == 8770
assert stored.completion_tokens == 82
@pytest.mark.asyncio
async def test_process_stats_event_with_object_value(
self,
service,
async_session,
stats_event_with_object_value,
v1_conversation_metadata,
):
"""Test processing stats event with object value."""
conversation_id, stored = v1_conversation_metadata
await service.process_stats_event(
stats_event_with_object_value, conversation_id
)
# Verify the update occurred
await async_session.refresh(stored)
assert stored.accumulated_cost == 0.05
assert stored.prompt_tokens == 1000
assert stored.completion_tokens == 100
@pytest.mark.asyncio
async def test_process_stats_event_no_usage_to_metrics(
self,
service,
async_session,
stats_event_no_usage_to_metrics,
v1_conversation_metadata,
):
"""Test processing stats event without usage_to_metrics."""
conversation_id, stored = v1_conversation_metadata
original_cost = stored.accumulated_cost
await service.process_stats_event(
stats_event_no_usage_to_metrics, conversation_id
)
# Verify update_conversation_statistics was NOT called
await async_session.refresh(stored)
assert stored.accumulated_cost == original_cost
@pytest.mark.asyncio
async def test_process_stats_event_service_error_handled(
self, service, stats_event_with_dict_value
):
"""Test that errors from service are caught and logged."""
conversation_id = uuid4()
# Should not raise an exception
with (
patch.object(
service,
'update_conversation_statistics',
side_effect=Exception('Database error'),
),
patch(
'openhands.app_server.app_conversation.sql_app_conversation_info_service.logger'
) as mock_logger,
):
await service.process_stats_event(
stats_event_with_dict_value, conversation_id
)
# Verify error was logged
mock_logger.exception.assert_called_once()
@pytest.mark.asyncio
async def test_process_stats_event_empty_usage_to_metrics(
self, service, async_session, v1_conversation_metadata
):
"""Test processing stats event with empty usage_to_metrics."""
conversation_id, stored = v1_conversation_metadata
original_cost = stored.accumulated_cost
# Create event with empty usage_to_metrics
event = ConversationStateUpdateEvent(
key='stats', value={'usage_to_metrics': {}}
)
await service.process_stats_event(event, conversation_id)
# Empty dict is falsy, so update_conversation_statistics should NOT be called
await async_session.refresh(stored)
assert stored.accumulated_cost == original_cost
# ---------------------------------------------------------------------------
# Integration tests for on_event endpoint
# ---------------------------------------------------------------------------
class TestOnEventStatsProcessing:
"""Test stats event processing in the on_event endpoint."""
@pytest.mark.asyncio
async def test_on_event_processes_stats_events(self):
"""Test that on_event processes stats events."""
from openhands.app_server.event_callback.webhook_router import on_event
from openhands.app_server.sandbox.sandbox_models import (
SandboxInfo,
SandboxStatus,
)
conversation_id = uuid4()
sandbox_id = 'sandbox_123'
# Create stats event
stats_event = ConversationStateUpdateEvent(
key='stats',
value={
'usage_to_metrics': {
'agent': {
'accumulated_cost': 0.1,
'accumulated_token_usage': {
'prompt_tokens': 1000,
},
}
}
},
)
# Create non-stats event
other_event = ConversationStateUpdateEvent(
key='execution_status', value='running'
)
events = [stats_event, other_event]
# Mock dependencies
mock_sandbox = SandboxInfo(
id=sandbox_id,
status=SandboxStatus.RUNNING,
session_api_key='test_key',
created_by_user_id='user_123',
sandbox_spec_id='spec_123',
)
mock_app_conversation_info = AppConversationInfo(
id=conversation_id,
sandbox_id=sandbox_id,
created_by_user_id='user_123',
)
mock_event_service = AsyncMock()
mock_app_conversation_info_service = AsyncMock()
mock_app_conversation_info_service.get_app_conversation_info.return_value = (
mock_app_conversation_info
)
# Set up process_stats_event to call update_conversation_statistics
async def process_stats_event_side_effect(event, conversation_id):
# Simulate what process_stats_event does - call update_conversation_statistics
from openhands.sdk.conversation.conversation_stats import ConversationStats
if isinstance(event.value, dict):
stats = ConversationStats.model_validate(event.value)
if stats and stats.usage_to_metrics:
await mock_app_conversation_info_service.update_conversation_statistics(
conversation_id, stats
)
mock_app_conversation_info_service.process_stats_event.side_effect = (
process_stats_event_side_effect
)
with (
patch(
'openhands.app_server.event_callback.webhook_router.valid_sandbox',
return_value=mock_sandbox,
),
patch(
'openhands.app_server.event_callback.webhook_router.valid_conversation',
return_value=mock_app_conversation_info,
),
patch(
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
) as mock_callbacks,
):
await on_event(
events=events,
conversation_id=conversation_id,
sandbox_info=mock_sandbox,
app_conversation_info_service=mock_app_conversation_info_service,
event_service=mock_event_service,
)
# Verify events were saved
assert mock_event_service.save_event.call_count == 2
# Verify stats event was processed
mock_app_conversation_info_service.update_conversation_statistics.assert_called_once()
# Verify callbacks were scheduled
mock_callbacks.assert_called_once()
@pytest.mark.asyncio
async def test_on_event_skips_non_stats_events(self):
"""Test that on_event skips non-stats events."""
from openhands.app_server.event_callback.webhook_router import on_event
from openhands.app_server.sandbox.sandbox_models import (
SandboxInfo,
SandboxStatus,
)
from openhands.events.action.message import MessageAction
conversation_id = uuid4()
sandbox_id = 'sandbox_123'
# Create non-stats events
events = [
ConversationStateUpdateEvent(key='execution_status', value='running'),
MessageAction(content='test'),
]
mock_sandbox = SandboxInfo(
id=sandbox_id,
status=SandboxStatus.RUNNING,
session_api_key='test_key',
created_by_user_id='user_123',
sandbox_spec_id='spec_123',
)
mock_app_conversation_info = AppConversationInfo(
id=conversation_id,
sandbox_id=sandbox_id,
created_by_user_id='user_123',
)
mock_event_service = AsyncMock()
mock_app_conversation_info_service = AsyncMock()
mock_app_conversation_info_service.get_app_conversation_info.return_value = (
mock_app_conversation_info
)
with (
patch(
'openhands.app_server.event_callback.webhook_router.valid_sandbox',
return_value=mock_sandbox,
),
patch(
'openhands.app_server.event_callback.webhook_router.valid_conversation',
return_value=mock_app_conversation_info,
),
patch(
'openhands.app_server.event_callback.webhook_router._run_callbacks_in_bg_and_close'
),
):
await on_event(
events=events,
conversation_id=conversation_id,
sandbox_info=mock_sandbox,
app_conversation_info_service=mock_app_conversation_info_service,
event_service=mock_event_service,
)
# Verify stats update was NOT called
mock_app_conversation_info_service.update_conversation_statistics.assert_not_called()