[Refactor]: Add LLMRegistry for llm services (#9589)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Rohit Malhotra (committed via GitHub), 2025-08-18 02:11:20 -04:00
parent 17b1a21296
commit 25d9cf2890
84 changed files with 2376 additions and 817 deletions
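The diff below rewrites the AgentSession unit tests around two new collaborators. As a rough mental model, here is a toy sketch — not the real openhands classes; it mirrors only the calls that appear in the tests (subscribe, get_llm, register_llm, get_combined_metrics, service_to_llm), and the real signatures differ. The pattern: an LLMRegistry creates one LLM per service ID and announces each creation to subscribers, and ConversationStats subscribes so it can aggregate metrics across every LLM in a conversation.

# Toy sketch only -- simplified stand-ins for the real openhands classes,
# kept to the calls exercised by the tests in this diff.
from typing import Callable


class Metrics:
    def __init__(self) -> None:
        self.accumulated_cost = 0.0

    def add_cost(self, cost: float) -> None:
        self.accumulated_cost += cost


class LLM:
    def __init__(self, service_id: str) -> None:
        self.service_id = service_id
        self.metrics = Metrics()


class LLMRegistry:
    """Creates one LLM per service ID and notifies subscribers on creation."""

    def __init__(self) -> None:
        self.service_to_llm: dict[str, LLM] = {}
        self._subscribers: list[Callable[[LLM], None]] = []

    def subscribe(self, callback: Callable[[LLM], None]) -> None:
        self._subscribers.append(callback)

    def get_llm(self, service_id: str, config=None) -> LLM:
        # Reuse the existing LLM for a service; otherwise create and announce it.
        if service_id not in self.service_to_llm:
            llm = LLM(service_id)
            self.service_to_llm[service_id] = llm
            for notify in self._subscribers:
                notify(llm)
        return self.service_to_llm[service_id]


class ConversationStats:
    """Tracks the metrics of every LLM registered for a conversation."""

    def __init__(self) -> None:
        self._llms: dict[str, LLM] = {}

    def register_llm(self, llm: LLM) -> None:
        self._llms[llm.service_id] = llm

    def get_combined_metrics(self) -> Metrics:
        combined = Metrics()
        for llm in self._llms.values():
            combined.add_cost(llm.metrics.accumulated_cost)
        return combined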


@@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
 from openhands.core.config.agent_config import AgentConfig
 from openhands.events import EventStream, EventStreamSubscriber
 from openhands.integrations.service_types import ProviderType
-from openhands.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.llm.metrics import Metrics
 from openhands.memory.memory import Memory
 from openhands.runtime.impl.action_execution.action_execution_client import (
     ActionExecutionClient,
 )
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.server.session.agent_session import AgentSession
 from openhands.storage.memory import InMemoryFileStore
@@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore
 @pytest.fixture
-def mock_agent():
-    """Create a properly configured mock agent with all required nested attributes"""
-    # Create the base mocks
-    agent = MagicMock(spec=Agent)
-    llm = MagicMock(spec=LLM)
-    metrics = MagicMock(spec=Metrics)
-    llm_config = MagicMock(spec=LLMConfig)
-    agent_config = MagicMock(spec=AgentConfig)
+def mock_llm_registry():
+    """Create a mock LLM registry that properly simulates LLM registration"""
+    config = OpenHandsConfig()
+    registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
+    return registry
-    # Configure the LLM config
-    llm_config.model = 'test-model'
-    llm_config.base_url = 'http://test'
-    llm_config.max_message_chars = 1000
-    # Configure the agent config
-    agent_config.disabled_microagents = []
-    agent_config.enable_mcp = True
+@pytest.fixture
+def mock_conversation_stats():
+    """Create a mock ConversationStats that properly simulates metrics tracking"""
+    file_store = InMemoryFileStore({})
+    stats = ConversationStats(
+        file_store=file_store, conversation_id='test-conversation', user_id='test-user'
+    )
+    return stats
-    # Set up the chain of mocks
-    llm.metrics = metrics
-    llm.config = llm_config
-    agent.llm = llm
-    agent.name = 'test-agent'
-    agent.sandbox_plugins = []
-    agent.config = agent_config
-    agent.prompt_manager = MagicMock()
-    return agent
+@pytest.fixture
+def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
+    """Connect the LLMRegistry and ConversationStats properly"""
+    # Subscribe to LLM registry events to track metrics
+    mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
+    return mock_llm_registry, mock_conversation_stats
+@pytest.fixture
+def make_mock_agent():
+    def _make_mock_agent(llm_registry):
+        agent = MagicMock(spec=Agent)
+        agent_config = MagicMock(spec=AgentConfig)
+        llm_config = LLMConfig(
+            model='gpt-4o',
+            api_key='test_key',
+            num_retries=2,
+            retry_min_wait=1,
+            retry_max_wait=2,
+        )
+        agent_config.disabled_microagents = []
+        agent_config.enable_mcp = True
+        llm_registry.service_to_llm.clear()
+        mock_llm = llm_registry.get_llm('agent_llm', llm_config)
+        agent.llm = mock_llm
+        agent.name = 'test-agent'
+        agent.sandbox_plugins = []
+        agent.config = agent_config
+        agent.prompt_manager = MagicMock()
+        return agent
+    return _make_mock_agent

 @pytest.mark.asyncio
-async def test_agent_session_start_with_no_state(mock_agent):
+async def test_agent_session_start_with_no_state(
+    make_mock_agent, mock_llm_registry, mock_conversation_stats
+):
     """Test that AgentSession.start() works correctly when there's no state to restore"""
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )
     # Create a mock runtime and set it up
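Using the toy classes sketched above, the fixture wiring in this hunk boils down to the following: subscribing ConversationStats to the registry before any LLM is created is what lets it see the agent's LLM.

# Toy illustration of the fixture wiring (stand-in classes from the sketch above).
registry = LLMRegistry()
stats = ConversationStats()
registry.subscribe(stats.register_llm)      # connected_registry_and_stats does this

agent_llm = registry.get_llm('agent_llm')   # creation notifies the subscriber
agent_llm.metrics.add_cost(0.05)
assert stats.get_combined_metrics().accumulated_cost == 0.05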
@@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
 @pytest.mark.asyncio
-async def test_agent_session_start_with_restored_state(mock_agent):
+async def test_agent_session_start_with_restored_state(
+    make_mock_agent, mock_llm_registry, mock_conversation_stats
+):
     """Test that AgentSession.start() works correctly when there's a state to restore"""
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )
     # Create a mock runtime and set it up
@@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):
 @pytest.mark.asyncio
-async def test_metrics_centralization_and_sharing(mock_agent):
-    """Test that metrics are centralized and shared between controller and agent."""
+async def test_metrics_centralization_via_conversation_stats(
+    make_mock_agent, connected_registry_and_stats
+):
+    """Test that metrics are centralized through the ConversationStats service."""
+    mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )
     # Create a mock runtime and set it up
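Both session-start tests above, and the tests that follow, now inject the registry and stats into AgentSession explicitly. A minimal stand-in mirroring only the keyword arguments used in these tests (the real constructor takes more) could look like:

class AgentSession:  # toy stand-in; the real class lives in openhands.server.session
    def __init__(self, sid, file_store, llm_registry, convo_stats):
        self.sid = sid
        self.file_store = file_store
        self.llm_registry = llm_registry  # hands out LLMs to components that need them
        self.convo_stats = convo_stats    # aggregates metrics for the whole conversation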
@@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
     memory = Memory(event_stream=mock_event_stream, sid='test-session')
     memory.microagents_dir = 'test-dir'
+    # The registry already has a real metrics object set up in the fixture
     # Patch necessary components
     with (
         patch(
@@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
             max_iterations=10,
         )
-        # Verify that the agent's LLM metrics and controller's state metrics are the same object
-        assert session.controller.agent.llm.metrics is session.controller.state.metrics
+        # Verify that the ConversationStats is properly set up
+        assert session.controller.state.convo_stats is mock_conversation_stats
-        # Add some metrics to the agent's LLM
+        # Add some metrics to the agent's LLM (simulating LLM usage)
         test_cost = 0.05
         session.controller.agent.llm.metrics.add_cost(test_cost)
-        # Verify that the cost is reflected in the controller's state metrics
-        assert session.controller.state.metrics.accumulated_cost == test_cost
+        # Verify that the cost is reflected in the combined metrics from the conversation stats
+        combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
+        assert combined_metrics.accumulated_cost == test_cost
-        # Create a test metrics object to simulate an observation with metrics
-        test_observation_metrics = Metrics()
-        test_observation_metrics.add_cost(0.1)
+        # Add more cost to simulate additional LLM usage
+        additional_cost = 0.1
+        session.controller.agent.llm.metrics.add_cost(additional_cost)
-        # Get the current accumulated cost before merging
-        current_cost = session.controller.state.metrics.accumulated_cost
+        # Verify the combined metrics reflect the total cost
+        combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
+        assert combined_metrics.accumulated_cost == test_cost + additional_cost
-        # Simulate merging metrics from an observation
-        session.controller.state_tracker.merge_metrics(test_observation_metrics)
-        # Verify that the merged metrics are reflected in both agent and controller
-        assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
-        assert (
-            session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
-        )
-        # Reset the agent and verify that metrics are not reset
+        # Reset the agent and verify that combined metrics are preserved
         session.controller.agent.reset()
-        # Metrics should still be the same after reset
-        assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
-        assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
-        assert session.controller.agent.llm.metrics is session.controller.state.metrics
+        # Combined metrics should still be preserved after agent reset
+        assert (
+            session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
+            == test_cost + additional_cost
+        )

 @pytest.mark.asyncio
-async def test_budget_control_flag_syncs_with_metrics(mock_agent):
+async def test_budget_control_flag_syncs_with_metrics(
+    make_mock_agent, connected_registry_and_stats
+):
     """Test that BudgetControlFlag's current value matches the accumulated costs."""
+    mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )
     # Create a mock runtime and set it up
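The point of the rewritten metrics test: there is no longer a single Metrics object shared by agent and controller. Each LLM keeps its own metrics, and ConversationStats combines them on demand, so resetting the agent cannot wipe the totals. In the toy model from above:

# Toy illustration: combined metrics aggregate across every registered LLM,
# so no single shared Metrics object is needed. 'draft_llm' is a hypothetical
# second service, purely for illustration.
registry = LLMRegistry()
stats = ConversationStats()
registry.subscribe(stats.register_llm)

agent_llm = registry.get_llm('agent_llm')
draft_llm = registry.get_llm('draft_llm')
agent_llm.metrics.add_cost(0.25)
draft_llm.metrics.add_cost(0.5)

assert stats.get_combined_metrics().accumulated_cost == 0.75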
@@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
     memory = Memory(event_stream=mock_event_stream, sid='test-session')
     memory.microagents_dir = 'test-dir'
+    # The registry already has a real metrics object set up in the fixture
     # Patch necessary components
     with (
         patch(
@@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
         assert session.controller.state.budget_flag.max_value == 1.0
         assert session.controller.state.budget_flag.current_value == 0.0
-        # Add some metrics to the agent's LLM
+        # Add some metrics to the agent's LLM (simulating LLM usage)
         test_cost = 0.05
         session.controller.agent.llm.metrics.add_cost(test_cost)
@@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
         session.controller.state_tracker.sync_budget_flag_with_metrics()
         assert session.controller.state.budget_flag.current_value == test_cost
-        # Create a test metrics object to simulate an observation with metrics
-        test_observation_metrics = Metrics()
-        test_observation_metrics.add_cost(0.1)
+        # Add more cost to simulate additional LLM usage
+        additional_cost = 0.1
+        session.controller.agent.llm.metrics.add_cost(additional_cost)
-        # Simulate merging metrics from an observation
-        session.controller.state_tracker.merge_metrics(test_observation_metrics)
+        # Sync again and verify the budget flag is updated
+        session.controller.state_tracker.sync_budget_flag_with_metrics()
+        assert (
+            session.controller.state.budget_flag.current_value
+            == test_cost + additional_cost
+        )
-        # Verify that the budget control flag's current value is updated to match the new accumulated cost
-        assert session.controller.state.budget_flag.current_value == test_cost + 0.1
-        # Reset the agent and verify that metrics and budget flag are not reset
+        # Reset the agent and verify that budget flag still reflects the accumulated cost
         session.controller.agent.reset()
         # Budget control flag should still reflect the accumulated cost after reset
-        assert session.controller.state.budget_flag.current_value == test_cost + 0.1
+        session.controller.state_tracker.sync_budget_flag_with_metrics()
+        assert (
+            session.controller.state.budget_flag.current_value
+            == test_cost + additional_cost
+        )

-def test_override_provider_tokens_with_custom_secret():
+def test_override_provider_tokens_with_custom_secret(
+    mock_llm_registry, mock_conversation_stats
+):
     """Test that override_provider_tokens_with_custom_secret works correctly.

     This test verifies that the method properly removes provider tokens when
@@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )
     # Create test data
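Finally, the budget test: the budget flag no longer mirrors a shared Metrics object but is refreshed from the combined metrics via sync_budget_flag_with_metrics(). A toy equivalent of that sync step, under the same assumptions as the sketches above (BudgetControlFlag here is a stand-in with just the two fields the test checks):

# Toy sketch of the budget sync exercised by the last test.
class BudgetControlFlag:
    def __init__(self, max_value):
        self.max_value = max_value
        self.current_value = 0.0

def sync_budget_flag_with_metrics(flag, stats):
    # Pull the authoritative accumulated cost from the conversation stats.
    flag.current_value = stats.get_combined_metrics().accumulated_cost

registry = LLMRegistry()
stats = ConversationStats()
registry.subscribe(stats.register_llm)
flag = BudgetControlFlag(max_value=1.0)

llm = registry.get_llm('agent_llm')
llm.metrics.add_cost(0.25)
llm.metrics.add_cost(0.5)
sync_budget_flag_with_metrics(flag, stats)
assert flag.current_value == 0.75  # stays in sync with the accumulated cost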