mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
[Refactor]: Add LLMRegistry for llm services (#9589)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Graham Neubig <neubig@gmail.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
@@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
from openhands.core.config.agent_config import AgentConfig
|
||||
from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.llm_registry import LLMRegistry
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||
ActionExecutionClient,
|
||||
)
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
@@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a properly configured mock agent with all required nested attributes"""
|
||||
# Create the base mocks
|
||||
agent = MagicMock(spec=Agent)
|
||||
llm = MagicMock(spec=LLM)
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
llm_config = MagicMock(spec=LLMConfig)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
def mock_llm_registry():
|
||||
"""Create a mock LLM registry that properly simulates LLM registration"""
|
||||
config = OpenHandsConfig()
|
||||
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
|
||||
return registry
|
||||
|
||||
# Configure the LLM config
|
||||
llm_config.model = 'test-model'
|
||||
llm_config.base_url = 'http://test'
|
||||
llm_config.max_message_chars = 1000
|
||||
|
||||
# Configure the agent config
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
@pytest.fixture
|
||||
def mock_conversation_stats():
|
||||
"""Create a mock ConversationStats that properly simulates metrics tracking"""
|
||||
file_store = InMemoryFileStore({})
|
||||
stats = ConversationStats(
|
||||
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
|
||||
)
|
||||
return stats
|
||||
|
||||
# Set up the chain of mocks
|
||||
llm.metrics = metrics
|
||||
llm.config = llm_config
|
||||
agent.llm = llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
|
||||
return agent
|
||||
@pytest.fixture
|
||||
def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
|
||||
"""Connect the LLMRegistry and ConversationStats properly"""
|
||||
# Subscribe to LLM registry events to track metrics
|
||||
mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
|
||||
return mock_llm_registry, mock_conversation_stats
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_mock_agent():
|
||||
def _make_mock_agent(llm_registry):
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent_config = MagicMock(spec=AgentConfig)
|
||||
llm_config = LLMConfig(
|
||||
model='gpt-4o',
|
||||
api_key='test_key',
|
||||
num_retries=2,
|
||||
retry_min_wait=1,
|
||||
retry_max_wait=2,
|
||||
)
|
||||
agent_config.disabled_microagents = []
|
||||
agent_config.enable_mcp = True
|
||||
llm_registry.service_to_llm.clear()
|
||||
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
|
||||
agent.llm = mock_llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
agent.config = agent_config
|
||||
agent.prompt_manager = MagicMock()
|
||||
return agent
|
||||
|
||||
return _make_mock_agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_no_state(mock_agent):
|
||||
async def test_agent_session_start_with_no_state(
|
||||
make_mock_agent, mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that AgentSession.start() works correctly when there's no state to restore"""
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
async def test_agent_session_start_with_restored_state(
|
||||
make_mock_agent, mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that AgentSession.start() works correctly when there's a state to restore"""
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
"""Test that metrics are centralized and shared between controller and agent."""
|
||||
async def test_metrics_centralization_via_conversation_stats(
|
||||
make_mock_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Test that metrics are centralized through the ConversationStats service."""
|
||||
|
||||
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# The registry already has a real metrics object set up in the fixture
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
@@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
# Verify that the agent's LLM metrics and controller's state metrics are the same object
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
# Verify that the ConversationStats is properly set up
|
||||
assert session.controller.state.convo_stats is mock_conversation_stats
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
# Add some metrics to the agent's LLM (simulating LLM usage)
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
# Verify that the cost is reflected in the controller's state metrics
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost
|
||||
# Verify that the cost is reflected in the combined metrics from the conversation stats
|
||||
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
|
||||
assert combined_metrics.accumulated_cost == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
# Add more cost to simulate additional LLM usage
|
||||
additional_cost = 0.1
|
||||
session.controller.agent.llm.metrics.add_cost(additional_cost)
|
||||
|
||||
# Get the current accumulated cost before merging
|
||||
current_cost = session.controller.state.metrics.accumulated_cost
|
||||
# Verify the combined metrics reflect the total cost
|
||||
combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
|
||||
assert combined_metrics.accumulated_cost == test_cost + additional_cost
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
|
||||
# Verify that the merged metrics are reflected in both agent and controller
|
||||
assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
|
||||
assert (
|
||||
session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
|
||||
)
|
||||
|
||||
# Reset the agent and verify that metrics are not reset
|
||||
# Reset the agent and verify that combined metrics are preserved
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Metrics should still be the same after reset
|
||||
assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
|
||||
assert session.controller.agent.llm.metrics is session.controller.state.metrics
|
||||
# Combined metrics should still be preserved after agent reset
|
||||
assert (
|
||||
session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
async def test_budget_control_flag_syncs_with_metrics(
|
||||
make_mock_agent, connected_registry_and_stats
|
||||
):
|
||||
"""Test that BudgetControlFlag's current value matches the accumulated costs."""
|
||||
|
||||
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
|
||||
mock_agent = make_mock_agent(mock_llm_registry)
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
@@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
memory = Memory(event_stream=mock_event_stream, sid='test-session')
|
||||
memory.microagents_dir = 'test-dir'
|
||||
|
||||
# The registry already has a real metrics object set up in the fixture
|
||||
|
||||
# Patch necessary components
|
||||
with (
|
||||
patch(
|
||||
@@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
assert session.controller.state.budget_flag.max_value == 1.0
|
||||
assert session.controller.state.budget_flag.current_value == 0.0
|
||||
|
||||
# Add some metrics to the agent's LLM
|
||||
# Add some metrics to the agent's LLM (simulating LLM usage)
|
||||
test_cost = 0.05
|
||||
session.controller.agent.llm.metrics.add_cost(test_cost)
|
||||
|
||||
@@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert session.controller.state.budget_flag.current_value == test_cost
|
||||
|
||||
# Create a test metrics object to simulate an observation with metrics
|
||||
test_observation_metrics = Metrics()
|
||||
test_observation_metrics.add_cost(0.1)
|
||||
# Add more cost to simulate additional LLM usage
|
||||
additional_cost = 0.1
|
||||
session.controller.agent.llm.metrics.add_cost(additional_cost)
|
||||
|
||||
# Simulate merging metrics from an observation
|
||||
session.controller.state_tracker.merge_metrics(test_observation_metrics)
|
||||
# Sync again and verify the budget flag is updated
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert (
|
||||
session.controller.state.budget_flag.current_value
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
# Verify that the budget control flag's current value is updated to match the new accumulated cost
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
|
||||
# Reset the agent and verify that metrics and budget flag are not reset
|
||||
# Reset the agent and verify that budget flag still reflects the accumulated cost
|
||||
session.controller.agent.reset()
|
||||
|
||||
# Budget control flag should still reflect the accumulated cost after reset
|
||||
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
|
||||
session.controller.state_tracker.sync_budget_flag_with_metrics()
|
||||
assert (
|
||||
session.controller.state.budget_flag.current_value
|
||||
== test_cost + additional_cost
|
||||
)
|
||||
|
||||
|
||||
def test_override_provider_tokens_with_custom_secret():
|
||||
def test_override_provider_tokens_with_custom_secret(
|
||||
mock_llm_registry, mock_conversation_stats
|
||||
):
|
||||
"""Test that override_provider_tokens_with_custom_secret works correctly.
|
||||
|
||||
This test verifies that the method properly removes provider tokens when
|
||||
@@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
|
||||
session = AgentSession(
|
||||
sid='test-session',
|
||||
file_store=file_store,
|
||||
llm_registry=mock_llm_registry,
|
||||
convo_stats=mock_conversation_stats,
|
||||
)
|
||||
|
||||
# Create test data
|
||||
|
||||
Reference in New Issue
Block a user