[Refactor]: Add LLMRegistry for llm services (#9589)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Rohit Malhotra
2025-08-18 02:11:20 -04:00
committed by GitHub
parent 17b1a21296
commit 25d9cf2890
84 changed files with 2376 additions and 817 deletions

View File

@@ -18,7 +18,7 @@ from openhands.events.action import (
 from openhands.events.event import EventSource
 from openhands.events.observation import BrowserOutputObservation
 from openhands.events.observation.observation import Observation
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.plugins import (
     PluginRequirement,
 )
@@ -102,15 +102,15 @@ class BrowsingAgent(Agent):
     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ) -> None:
         """Initializes a new instance of the BrowsingAgent class.
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
         # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
         action_subsets = ['chat', 'bid']

View File

@@ -3,6 +3,8 @@ import sys
 from collections import deque
 from typing import TYPE_CHECKING

+from openhands.llm.llm_registry import LLMRegistry
+
 if TYPE_CHECKING:
     from litellm import ChatCompletionToolParam
@@ -32,7 +34,6 @@ from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.events.action import AgentFinishAction, MessageAction
 from openhands.events.event import Event
-from openhands.llm.llm import LLM
 from openhands.llm.llm_utils import check_tools
 from openhands.memory.condenser import Condenser
 from openhands.memory.condenser.condenser import Condensation, View
@@ -74,18 +75,13 @@ class CodeActAgent(Agent):
         JupyterRequirement(),
     ]

-    def __init__(
-        self,
-        llm: LLM,
-        config: AgentConfig,
-    ) -> None:
+    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
         """Initializes a new instance of the CodeActAgent class.
         Parameters:
-        - llm (LLM): The llm to be used by this agent
         - config (AgentConfig): The configuration for this agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         self.pending_actions: deque['Action'] = deque()
         self.reset()
         self.tools = self._get_tools()
@@ -93,7 +89,7 @@ class CodeActAgent(Agent):
         # Create a ConversationMemory instance
         self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
-        self.condenser = Condenser.from_config(self.config.condenser)
+        self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
         logger.debug(f'Using condenser: {type(self.condenser)}')

     @property

View File

@@ -22,7 +22,7 @@ from openhands.events.observation import (
     Observation,
 )
 from openhands.events.serialization.event import event_to_dict
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 """
 FIXME: There are a few problems this surfaced
@@ -42,8 +42,12 @@ class DummyAgent(Agent):
     without making any LLM calls.
     """

-    def __init__(self, llm: LLM, config: AgentConfig):
-        super().__init__(llm, config)
+    def __init__(
+        self,
+        config: AgentConfig,
+        llm_registry: LLMRegistry,
+    ):
+        super().__init__(config, llm_registry)
         self.steps: list[ActionObs] = [
             {
                 'action': MessageAction('Time to get started!'),

View File

@@ -4,7 +4,7 @@ import openhands.agenthub.loc_agent.function_calling as locagent_function_callin
 from openhands.agenthub.codeact_agent import CodeActAgent
 from openhands.core.config import AgentConfig
 from openhands.core.logger import openhands_logger as logger
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 if TYPE_CHECKING:
     from openhands.events.action import Action
@@ -16,8 +16,8 @@ class LocAgent(CodeActAgent):
     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ) -> None:
         """Initializes a new instance of the LocAgent class.
@@ -25,7 +25,8 @@ class LocAgent(CodeActAgent):
         - llm (LLM): The llm to be used by this agent
         - config (AgentConfig): The configuration for the agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         self.tools = locagent_function_calling.get_tools()
         logger.debug(

View File

@@ -3,6 +3,8 @@
 import os
 from typing import TYPE_CHECKING

+from openhands.llm.llm_registry import LLMRegistry
+
 if TYPE_CHECKING:
     from litellm import ChatCompletionToolParam
@@ -15,7 +17,6 @@ from openhands.agenthub.readonly_agent import (
 )
 from openhands.core.config import AgentConfig
 from openhands.core.logger import openhands_logger as logger
-from openhands.llm.llm import LLM
 from openhands.utils.prompt import PromptManager
@@ -37,17 +38,16 @@ class ReadOnlyAgent(CodeActAgent):
     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ) -> None:
         """Initializes a new instance of the ReadOnlyAgent class.
         Parameters:
-        - llm (LLM): The llm to be used by this agent
         - config (AgentConfig): The configuration for this agent
         """
         # Initialize the CodeActAgent class; some of it is overridden with class methods
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         logger.debug(
             f'TOOLS loaded for ReadOnlyAgent: {", ".join([tool.get("function").get("name") for tool in self.tools])}'

View File

@@ -16,7 +16,7 @@ from openhands.events.action import (
 from openhands.events.event import EventSource
 from openhands.events.observation import BrowserOutputObservation
 from openhands.events.observation.observation import Observation
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.plugins import (
     PluginRequirement,
 )
@@ -127,17 +127,13 @@ class VisualBrowsingAgent(Agent):
     sandbox_plugins: list[PluginRequirement] = []
     response_parser = BrowsingResponseParser()

-    def __init__(
-        self,
-        llm: LLM,
-        config: AgentConfig,
-    ) -> None:
+    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
         """Initializes a new instance of the VisualBrowsingAgent class.
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
         # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
         action_subsets = [

View File

@@ -83,6 +83,7 @@ from openhands.microagent.microagent import BaseMicroagent
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
 from openhands.storage.settings.file_settings_store import FileSettingsStore
+from openhands.utils.utils import create_registry_and_convo_stats

 async def cleanup_session(
@@ -147,9 +148,16 @@ async def run_session(
         None, display_initialization_animation, 'Initializing...', is_loaded
     )

-    agent = create_agent(config)
+    llm_registry, convo_stats, config = create_registry_and_convo_stats(
+        config,
+        sid,
+        None,
+    )
+    agent = create_agent(config, llm_registry)
     runtime = create_runtime(
         config,
+        llm_registry,
         sid=sid,
         headless_mode=True,
         agent=agent,
@@ -161,7 +169,7 @@ async def run_session(
     runtime.subscribe_to_shell_stream(stream_to_console)

-    controller, initial_state = create_controller(agent, runtime, config)
+    controller, initial_state = create_controller(agent, runtime, config, convo_stats)
     event_stream = runtime.event_stream

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING

+from openhands.llm.llm_registry import LLMRegistry
+
 if TYPE_CHECKING:
     from openhands.controller.state.state import State
     from openhands.events.action import Action
@@ -17,7 +19,6 @@ from openhands.core.exceptions import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.event import EventSource
-from openhands.llm.llm import LLM
 from openhands.runtime.plugins import PluginRequirement
@@ -38,10 +39,11 @@ class Agent(ABC):
     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ):
-        self.llm = llm
+        self.llm = llm_registry.get_llm_from_agent_config('agent', config)
+        self.llm_registry = llm_registry
         self.config = config
         self._complete = False
         self._prompt_manager: 'PromptManager' | None = None
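With this base-class change, agents no longer receive an LLM instance; Agent.__init__ resolves one from the registry under the 'agent' service ID and keeps the registry around for delegates and condensers. A minimal, hedged sketch of the new construction path in Python (the default-constructed config and the direct CodeActAgent instantiation are illustrative, not part of this diff):

    from openhands.agenthub.codeact_agent import CodeActAgent
    from openhands.core.config.openhands_config import OpenHandsConfig
    from openhands.llm.llm_registry import LLMRegistry

    config = OpenHandsConfig()  # illustrative; real callers load this from settings
    registry = LLMRegistry(config)  # eagerly registers the default 'agent' LLM
    agent = CodeActAgent(config.get_agent_config(config.default_agent), registry)
    # The agent's LLM is the shared, registry-managed 'agent' service instance:
    assert agent.llm is registry.get_active_llm()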

View File

@@ -73,9 +73,9 @@ from openhands.events.observation import (
     Observation,
 )
 from openhands.events.serialization.event import truncate_content
-from openhands.llm.llm import LLM
 from openhands.llm.metrics import Metrics
 from openhands.runtime.runtime_status import RuntimeStatus
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.files import FileStore

 # note: RESUME is only available on web GUI
@@ -109,6 +109,7 @@ class AgentController:
         self,
         agent: Agent,
         event_stream: EventStream,
+        convo_stats: ConversationStats,
         iteration_delta: int,
         budget_per_task_delta: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
@@ -148,6 +149,7 @@ class AgentController:
         self.agent = agent
         self.headless_mode = headless_mode
         self.is_delegate = is_delegate
+        self.convo_stats = convo_stats
         # the event stream must be set before maybe subscribing to it
         self.event_stream = event_stream
@@ -163,6 +165,7 @@ class AgentController:
         # state from the previous session, state from a parent agent, or a fresh state
         self.set_initial_state(
             state=initial_state,
+            convo_stats=convo_stats,
             max_iterations=iteration_delta,
             max_budget_per_task=budget_per_task_delta,
             confirmation_mode=confirmation_mode,
@@ -477,11 +480,6 @@ class AgentController:
             log_level, str(observation_to_print), extra={'msg_type': 'OBSERVATION'}
         )
-        # TODO: these metrics come from the draft editor, and they get accumulated into controller's state metrics and the agent's llm metrics
-        # In the future, we should have a more principled way to sharing metrics across all LLM instances for a given conversation
-        if observation.llm_metrics is not None:
-            self.state_tracker.merge_metrics(observation.llm_metrics)
         # this happens for runnable actions and microagent actions
         if self._pending_action and self._pending_action.id == observation.cause:
             if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
@@ -657,14 +655,10 @@ class AgentController:
         """
         agent_cls: type[Agent] = Agent.get_cls(action.agent)
         agent_config = self.agent_configs.get(action.agent, self.agent.config)
-        llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
         # Make sure metrics are shared between parent and child for global accumulation
-        llm = LLM(
-            config=llm_config,
-            retry_listener=self.agent.llm.retry_listener,
-            metrics=self.state.metrics,
-        )
-        delegate_agent = agent_cls(llm=llm, config=agent_config)
+        delegate_agent = agent_cls(
+            config=agent_config, llm_registry=self.agent.llm_registry
+        )
         # Take a snapshot of the current metrics before starting the delegate
         state = State(
@@ -683,7 +677,7 @@ class AgentController:
         )
         self.log(
             'debug',
-            f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
+            f'start delegate, creating agent {delegate_agent.name}',
         )
         # Create the delegate with is_delegate=True so it does NOT subscribe directly
@@ -693,6 +687,7 @@ class AgentController:
             user_id=self.user_id,
             agent=delegate_agent,
             event_stream=self.event_stream,
+            convo_stats=self.convo_stats,
             iteration_delta=self._initial_max_iterations,
             budget_per_task_delta=self._initial_max_budget_per_task,
             agent_to_llm_config=self.agent_to_llm_config,
@@ -795,13 +790,8 @@ class AgentController:
             extra={'msg_type': 'STEP'},
         )
-        # Ensure budget control flag is synchronized with the latest metrics.
-        # In the future, we should centralized the use of one LLM object per conversation.
-        # This will help us unify the cost for auto generating titles, running the condensor, etc.
-        # Before many microservices will touh the same llm cost field, we should sync with the budget flag for the controller
-        # and check that we haven't exceeded budget BEFORE executing an agent step.
+        # Synchronize spend across all llm services with the budget flag
         self.state_tracker.sync_budget_flag_with_metrics()

         if self._is_stuck():
             await self._react_to_exception(
                 AgentStuckInLoopError('Agent got stuck in a loop')
@@ -961,14 +951,15 @@ class AgentController:
     def set_initial_state(
         self,
         state: State | None,
+        convo_stats: ConversationStats,
         max_iterations: int,
         max_budget_per_task: float | None,
         confirmation_mode: bool = False,
     ):
         self.state_tracker.set_initial_state(
             self.id,
-            self.agent,
             state,
+            convo_stats,
             max_iterations,
             max_budget_per_task,
             confirmation_mode,
@@ -1009,37 +1000,20 @@ class AgentController:
             action: The action to attach metrics to
         """
         # Get metrics from agent LLM
-        agent_metrics = self.state.metrics
-        # Get metrics from condenser LLM if it exists
-        condenser_metrics: Metrics | None = None
-        if hasattr(self.agent, 'condenser') and hasattr(self.agent.condenser, 'llm'):
-            condenser_metrics = self.agent.condenser.llm.metrics
-        # Create a new minimal metrics object with just what the frontend needs
-        metrics = Metrics(model_name=agent_metrics.model_name)
-        # Set accumulated cost (sum of agent and condenser costs)
-        metrics.accumulated_cost = agent_metrics.accumulated_cost
-        if condenser_metrics:
-            metrics.accumulated_cost += condenser_metrics.accumulated_cost
+        metrics = self.convo_stats.get_combined_metrics()
+        # Create a clean copy with only the fields we want to keep
+        clean_metrics = Metrics()
+        clean_metrics.accumulated_cost = metrics.accumulated_cost
+        clean_metrics._accumulated_token_usage = copy.deepcopy(
+            metrics.accumulated_token_usage
+        )
         # Add max_budget_per_task to metrics
         if self.state.budget_flag:
-            metrics.max_budget_per_task = self.state.budget_flag.max_value
-        # Set accumulated token usage (sum of agent and condenser token usage)
-        # Use a deep copy to ensure we don't modify the original object
-        metrics._accumulated_token_usage = (
-            agent_metrics.accumulated_token_usage.model_copy(deep=True)
-        )
-        if condenser_metrics:
-            metrics._accumulated_token_usage = (
-                metrics._accumulated_token_usage
-                + condenser_metrics.accumulated_token_usage
-            )
-        action.llm_metrics = metrics
+            clean_metrics.max_budget_per_task = self.state.budget_flag.max_value
+        action.llm_metrics = clean_metrics
         # Log the metrics information for debugging
         # Get the latest usage directly from the agent's metrics

View File

@@ -21,6 +21,7 @@ from openhands.events.action.agent import AgentFinishAction
 from openhands.events.event import Event, EventSource
 from openhands.llm.metrics import Metrics
 from openhands.memory.view import View
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.files import FileStore
 from openhands.storage.locations import get_conversation_agent_state_filename
@@ -84,6 +85,7 @@ class State:
             limit_increase_amount=100, current_value=0, max_value=100
         )
     )
+    convo_stats: ConversationStats | None = None
     budget_flag: BudgetControlFlag | None = None
     confirmation_mode: bool = False
     history: list[Event] = field(default_factory=list)
@@ -91,8 +93,7 @@ class State:
     outputs: dict = field(default_factory=dict)
     agent_state: AgentState = AgentState.LOADING
     resume_state: AgentState | None = None
-    # global metrics for the current task
-    metrics: Metrics = field(default_factory=Metrics)
     # root agent has level 0, and every delegate increases the level by one
     delegate_level: int = 0
     # start_id and end_id track the range of events in history
@@ -116,9 +117,14 @@ class State:
     local_metrics: Metrics | None = None
     delegates: dict[tuple[int, int], tuple[str, str]] | None = None
+    metrics: Metrics = field(default_factory=Metrics)

     def save_to_session(
         self, sid: str, file_store: FileStore, user_id: str | None
     ) -> None:
+        convo_stats = self.convo_stats
+        self.convo_stats = None  # Don't save convo stats, handles itself
         pickled = pickle.dumps(self)
         logger.debug(f'Saving state to session {sid}:{self.agent_state}')
         encoded = base64.b64encode(pickled).decode('utf-8')
@@ -138,6 +144,8 @@ class State:
             logger.error(f'Failed to save state to session: {e}')
             raise e

+        self.convo_stats = convo_stats  # restore reference
+
     @staticmethod
     def restore_from_session(
         sid: str, file_store: FileStore, user_id: str | None = None

View File

@@ -1,4 +1,3 @@
-from openhands.controller.agent import Agent
 from openhands.controller.state.control_flags import (
     BudgetControlFlag,
     IterationControlFlag,
@@ -14,7 +13,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
 from openhands.events.observation.empty import NullObservation
 from openhands.events.serialization.event import event_to_trajectory
 from openhands.events.stream import EventStream
-from openhands.llm.metrics import Metrics
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.files import FileStore
@@ -51,8 +50,8 @@ class StateTracker:
     def set_initial_state(
         self,
         id: str,
-        agent: Agent,
         state: State | None,
+        convo_stats: ConversationStats,
         max_iterations: int,
         max_budget_per_task: float | None,
         confirmation_mode: bool = False,
@@ -75,6 +74,7 @@ class StateTracker:
                 session_id=id.removesuffix('-delegate'),
                 user_id=self.user_id,
                 inputs={},
+                convo_stats=convo_stats,
                 iteration_flag=IterationControlFlag(
                     limit_increase_amount=max_iterations,
                     current_value=0,
@@ -99,13 +99,7 @@ class StateTracker:
             if self.state.start_id <= -1:
                 self.state.start_id = 0
-            logger.info(
-                f'AgentController {id} initializing history from event {self.state.start_id}',
-            )
-
-        # Share the state metrics with the agent's LLM metrics
-        # This ensures that all accumulated metrics are always in sync between controller and llm
-        agent.llm.metrics = self.state.metrics
+        state.convo_stats = convo_stats

     def _init_history(self, event_stream: EventStream) -> None:
         """Initializes the agent's history from the event stream.
@@ -254,6 +248,9 @@ class StateTracker:
         if self.sid and self.file_store:
             self.state.save_to_session(self.sid, self.file_store, self.user_id)

+        if self.state.convo_stats:
+            self.state.convo_stats.save_metrics()
+
     def run_control_flags(self):
         """Performs one step of the control flags"""
         self.state.iteration_flag.step()
@@ -264,20 +261,8 @@ class StateTracker:
         """Ensures that budget flag is up to date with accumulated costs from llm completions
         Budget flag will monitor for when budget is exceeded
         """
-        if self.state.budget_flag:
-            self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
-
-    def merge_metrics(self, metrics: Metrics):
-        """Merges metrics with the state metrics
-        NOTE: this should be refactored in the future. We should have services (draft llm, title autocomplete, condenser, etc)
-        use their own LLMs, but the metrics object should be shared. This way we have one source of truth for accumulated costs from
-        all services
-        This would prevent having fragmented stores for metrics, and we don't have the burden of deciding where and how to store them
-        if we decide introduce more specialized services that require llm completions
-        """
-        self.state.metrics.merge(metrics)
-        if self.state.budget_flag:
-            self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
+        # Sync cost across all llm services from llm registry
+        if self.state.budget_flag and self.state.convo_stats:
+            self.state.budget_flag.current_value = (
+                self.state.convo_stats.get_combined_metrics().accumulated_cost
+            )

View File

@@ -157,13 +157,16 @@ class OpenHandsConfig(BaseModel):
         """Get a map of agent names to llm configs."""
         return {name: self.get_llm_config_from_agent(name) for name in self.agents}

-    def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
-        agent_config: AgentConfig = self.get_agent_config(name)
+    def get_llm_config_from_agent_config(self, agent_config: AgentConfig):
         llm_config_name = (
             agent_config.llm_config if agent_config.llm_config is not None else 'llm'
         )
         return self.get_llm_config(llm_config_name)

+    def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
+        agent_config: AgentConfig = self.get_agent_config(name)
+        return self.get_llm_config_from_agent_config(agent_config)
+
     def get_agent_configs(self) -> dict[str, AgentConfig]:
         return self.agents

View File

@@ -6,7 +6,6 @@ from typing import Callable, Protocol
 import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
 import openhands.cli.suppress_warnings  # noqa: F401
-from openhands.controller.agent import Agent
 from openhands.controller.replay import ReplayManager
 from openhands.controller.state.state import State
 from openhands.core.config import (
@@ -33,10 +32,12 @@ from openhands.events.action.action import Action
 from openhands.events.event import Event
 from openhands.events.observation import AgentStateChangedObservation
 from openhands.io import read_input, read_task
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.mcp import add_mcp_tools_to_agent
 from openhands.memory.memory import Memory
 from openhands.runtime.base import Runtime
 from openhands.utils.async_utils import call_async_from_sync
+from openhands.utils.utils import create_registry_and_convo_stats

 class FakeUserResponseFunc(Protocol):
@@ -53,12 +54,12 @@ async def run_controller(
     initial_user_action: Action,
     sid: str | None = None,
     runtime: Runtime | None = None,
-    agent: Agent | None = None,
     exit_on_message: bool = False,
     fake_user_response_fn: FakeUserResponseFunc | None = None,
     headless_mode: bool = True,
     memory: Memory | None = None,
     conversation_instructions: str | None = None,
+    llm_registry: LLMRegistry | None = None,
 ) -> State | None:
     """Main coroutine to run the agent controller with task input flexibility.
@@ -70,7 +71,6 @@ async def run_controller(
         sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
             Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
         runtime: (optional) A runtime for the agent to run on.
-        agent: (optional) A agent to run.
         exit_on_message: quit if agent asks for a message from user (optional)
         fake_user_response_fn: An optional function that receives the current state
             (could be None) and returns a fake user response.
@@ -98,8 +98,13 @@ async def run_controller(
     """
     sid = sid or generate_sid(config)

-    if agent is None:
-        agent = create_agent(config)
+    llm_registry, convo_stats, config = create_registry_and_convo_stats(
+        config,
+        sid,
+        None,
+    )
+    agent = create_agent(config, llm_registry)

     # when the runtime is created, it will be connected and clone the selected repository
     repo_directory = None
@@ -108,6 +113,7 @@ async def run_controller(
     repo_tokens = get_provider_tokens()
     runtime = create_runtime(
         config,
+        llm_registry,
         sid=sid,
         headless_mode=headless_mode,
         agent=agent,
@@ -159,7 +165,7 @@ async def run_controller(
     )
     controller, initial_state = create_controller(
-        agent, runtime, config, replay_events=replay_events
+        agent, runtime, config, convo_stats, replay_events=replay_events
     )

     assert isinstance(initial_user_action, Action), (

View File

@@ -21,12 +21,13 @@ from openhands.integrations.provider import (
     ProviderToken,
     ProviderType,
 )
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.memory import Memory
 from openhands.microagent.microagent import BaseMicroagent
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
 from openhands.security import SecurityAnalyzer, options
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage import get_file_store
 from openhands.storage.data_models.user_secrets import UserSecrets
 from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
@@ -34,6 +35,7 @@ from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
 def create_runtime(
     config: OpenHandsConfig,
+    llm_registry: LLMRegistry,
     sid: str | None = None,
     headless_mode: bool = True,
     agent: Agent | None = None,
@@ -82,6 +84,7 @@ def create_runtime(
         sid=session_id,
         plugins=agent_cls.sandbox_plugins,
         headless_mode=headless_mode,
+        llm_registry=llm_registry,
         git_provider_tokens=git_provider_tokens,
     )
@@ -203,16 +206,11 @@ def create_memory(
     return memory

-def create_agent(config: OpenHandsConfig) -> Agent:
+def create_agent(config: OpenHandsConfig, llm_registry: LLMRegistry) -> Agent:
     agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
     agent_config = config.get_agent_config(config.default_agent)
-    llm_config = config.get_llm_config_from_agent(config.default_agent)
-
-    agent = agent_cls(
-        llm=LLM(config=llm_config),
-        config=agent_config,
-    )
+    config.get_llm_config_from_agent(config.default_agent)
+    agent = agent_cls(config=agent_config, llm_registry=llm_registry)
     return agent
@@ -220,6 +218,7 @@ def create_controller(
     agent: Agent,
     runtime: Runtime,
     config: OpenHandsConfig,
+    convo_stats: ConversationStats,
     headless_mode: bool = True,
     replay_events: list[Event] | None = None,
 ) -> tuple[AgentController, State | None]:
@@ -237,6 +236,7 @@ def create_controller(
     controller = AgentController(
         agent=agent,
+        convo_stats=convo_stats,
         iteration_delta=config.max_iterations,
         budget_per_task_delta=config.max_budget_per_task,
         agent_to_llm_config=config.get_agent_to_llm_config_map(),
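Taken together, the CLI and headless entry points above now follow the same wiring order: build the registry and conversation stats first, then agent, runtime, and controller. A hedged, self-contained sketch of that flow in Python (it assumes these helpers live in openhands.core.setup and openhands.utils.utils as in the upstream layout; the config construction and session id are illustrative):

    from openhands.core.config.openhands_config import OpenHandsConfig
    from openhands.core.setup import create_agent, create_controller, create_runtime
    from openhands.utils.utils import create_registry_and_convo_stats

    config = OpenHandsConfig()  # illustrative; real callers load this from settings/CLI args
    sid = 'example-session'     # illustrative session id

    llm_registry, convo_stats, config = create_registry_and_convo_stats(config, sid, None)
    agent = create_agent(config, llm_registry)
    runtime = create_runtime(config, llm_registry, sid=sid, headless_mode=True, agent=agent)
    controller, initial_state = create_controller(agent, runtime, config, convo_stats)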

View File

@@ -8,6 +8,7 @@ from typing import Any, Callable
 import httpx

 from openhands.core.config import LLMConfig
+from openhands.llm.metrics import Metrics

 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -34,7 +35,6 @@ from openhands.llm.fn_call_converter import (
     convert_fncall_messages_to_non_fncall_messages,
     convert_non_fncall_messages_to_fncall_messages,
 )
-from openhands.llm.metrics import Metrics
 from openhands.llm.retry_mixin import RetryMixin

 __all__ = ['LLM']
@@ -133,6 +133,7 @@ class LLM(RetryMixin, DebugMixin):
     def __init__(
         self,
         config: LLMConfig,
+        service_id: str,
         metrics: Metrics | None = None,
         retry_listener: Callable[[int, int], None] | None = None,
     ) -> None:
@@ -145,11 +146,12 @@ class LLM(RetryMixin, DebugMixin):
             metrics: The metrics to use.
         """
         self._tried_model_info = False
+        self.cost_metric_supported: bool = True
+        self.config: LLMConfig = copy.deepcopy(config)
+        self.service_id = service_id
         self.metrics: Metrics = (
             metrics if metrics is not None else Metrics(model_name=config.model)
         )
-        self.cost_metric_supported: bool = True
-        self.config: LLMConfig = copy.deepcopy(config)
         self.model_info: ModelInfo | None = None
         self.retry_listener = retry_listener
@@ -408,8 +410,7 @@ class LLM(RetryMixin, DebugMixin):
         assert self.config.log_completions_folder is not None
         log_file = os.path.join(
             self.config.log_completions_folder,
-            # use the metric model name (for draft editor)
-            f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
+            f'{self.config.model.replace("/", "__")}-{time.time()}.json',
         )

         # set up the dict to be logged
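LLM now takes a mandatory service_id, which the registry uses to key shared instances and attribute cost, and completion logs are named after config.model rather than the metrics model name. A minimal sketch of direct construction under the new signature (the 'draft_editor' id and model name are placeholders, not from this diff):

    from openhands.core.config import LLMConfig
    from openhands.llm.llm import LLM

    # Any code that still constructs an LLM by hand must now name its service.
    draft_llm = LLM(config=LLMConfig(model='gpt-4o-mini'), service_id='draft_editor')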

View File

@@ -0,0 +1,132 @@
+import copy
+from typing import Any, Callable
+from uuid import uuid4
+
+from pydantic import BaseModel, ConfigDict
+
+from openhands.core.config.agent_config import AgentConfig
+from openhands.core.config.llm_config import LLMConfig
+from openhands.core.config.openhands_config import OpenHandsConfig
+from openhands.core.logger import openhands_logger as logger
+from openhands.llm.llm import LLM
+
+
+class RegistryEvent(BaseModel):
+    llm: LLM
+    service_id: str
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
+
+
+class LLMRegistry:
+    def __init__(
+        self,
+        config: OpenHandsConfig,
+        agent_cls: str | None = None,
+        retry_listener: Callable[[int, int], None] | None = None,
+    ):
+        self.registry_id = str(uuid4())
+        self.config = copy.deepcopy(config)
+        self.retry_listner = retry_listener
+        self.agent_to_llm_config = self.config.get_agent_to_llm_config_map()
+        self.service_to_llm: dict[str, LLM] = {}
+        self.subscriber: Callable[[Any], None] | None = None
+
+        selected_agent_cls = self.config.default_agent
+        if agent_cls:
+            selected_agent_cls = agent_cls
+        agent_name = selected_agent_cls if selected_agent_cls is not None else 'agent'
+        llm_config = self.config.get_llm_config_from_agent(agent_name)
+        self.active_agent_llm: LLM = self.get_llm('agent', llm_config)
+
+    def _create_new_llm(
+        self, service_id: str, config: LLMConfig, with_listener: bool = True
+    ) -> LLM:
+        if with_listener:
+            llm = LLM(
+                service_id=service_id, config=config, retry_listener=self.retry_listner
+            )
+        else:
+            llm = LLM(service_id=service_id, config=config)
+        self.service_to_llm[service_id] = llm
+        self.notify(RegistryEvent(llm=llm, service_id=service_id))
+        return llm
+
+    def request_extraneous_completion(
+        self, service_id: str, llm_config: LLMConfig, messages: list[dict[str, str]]
+    ) -> str:
+        logger.info(f'extraneous completion: {service_id}')
+        if service_id not in self.service_to_llm:
+            self._create_new_llm(
+                config=llm_config, service_id=service_id, with_listener=False
+            )
+        llm = self.service_to_llm[service_id]
+        response = llm.completion(messages=messages)
+        return response.choices[0].message.content.strip()
+
+    def get_llm_from_agent_config(self, service_id: str, agent_config: AgentConfig):
+        llm_config = self.config.get_llm_config_from_agent_config(agent_config)
+        if service_id in self.service_to_llm:
+            if self.service_to_llm[service_id].config != llm_config:
+                # TODO: update llm config internally
+                # Done when agent delegates has different config, we should reuse the existing LLM
+                pass
+            return self.service_to_llm[service_id]
+        return self._create_new_llm(config=llm_config, service_id=service_id)
+
+    def get_llm(
+        self,
+        service_id: str,
+        config: LLMConfig | None = None,
+    ):
+        logger.info(
+            f'[LLM registry {self.registry_id}]: Registering service for {service_id}'
+        )
+        # Attempting to switch configs for existing LLM
+        if (
+            service_id in self.service_to_llm
+            and self.service_to_llm[service_id].config != config
+        ):
+            raise ValueError(
+                f'Requesting same service ID {service_id} with different config, use a new service ID'
+            )
+        if service_id in self.service_to_llm:
+            return self.service_to_llm[service_id]
+        if not config:
+            raise ValueError('Requesting new LLM without specifying LLM config')
+        return self._create_new_llm(config=config, service_id=service_id)
+
+    def get_active_llm(self) -> LLM:
+        return self.active_agent_llm
+
+    def _set_active_llm(self, service_id) -> None:
+        if service_id not in self.service_to_llm:
+            raise ValueError(f'Unrecognized service ID: {service_id}')
+        self.active_agent_llm = self.service_to_llm[service_id]
+
+    def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
+        self.subscriber = callback
+        # Subscriptions happen after default llm is initialized
+        # Notify service of this llm
+        self.notify(
+            RegistryEvent(
+                llm=self.active_agent_llm, service_id=self.active_agent_llm.service_id
+            )
+        )
+
+    def notify(self, event: RegistryEvent):
+        if self.subscriber:
+            try:
+                self.subscriber(event)
+            except Exception as e:
+                logger.warning(f'Failed to emit event: {e}')
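A hedged sketch of how the new registry is meant to be used, based only on the methods shown above: LLMs are keyed by service ID and reused across callers, a single subscriber (ConversationStats in this PR) is notified whenever an LLM is created so its metrics can be tracked centrally, and side services can request one-off completions. It assumes config is an OpenHandsConfig; the 'conversation_title' service ID and the print callback are illustrative.

    registry = LLMRegistry(config)

    # ConversationStats subscribes like this so every LLM's metrics are tracked in one place.
    registry.subscribe(lambda event: print(f'LLM created for service: {event.service_id}'))

    # Reuse-or-create an LLM for a named service; the same id with the same config
    # returns the same shared instance.
    condenser_llm = registry.get_llm('condenser', config.get_llm_config('llm'))

    # One-off completion for a side service such as title generation.
    title = registry.request_extraneous_completion(
        service_id='conversation_title',
        llm_config=config.get_llm_config('llm'),
        messages=[{'role': 'user', 'content': 'Summarize this conversation in five words.'}],
    )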

View File

@@ -10,6 +10,7 @@ from openhands.controller.state.state import State
 from openhands.core.config.condenser_config import CondenserConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.action.agent import CondensationAction
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.view import View

 CONDENSER_METADATA_KEY = 'condenser_meta'
@@ -144,7 +145,9 @@ class Condenser(ABC):
         CONDENSER_REGISTRY[configuration_type] = cls

     @classmethod
-    def from_config(cls, config: CondenserConfig) -> Condenser:
+    def from_config(
+        cls, config: CondenserConfig, llm_registry: LLMRegistry
+    ) -> Condenser:
         """Create a condenser from a configuration object.
         Args:
@@ -158,7 +161,7 @@ class Condenser(ABC):
         """
         try:
             condenser_class = CONDENSER_REGISTRY[type(config)]
-            return condenser_class.from_config(config)
+            return condenser_class.from_config(config, llm_registry)
         except KeyError:
             raise ValueError(f'Unknown condenser config: {config}')
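Every condenser's from_config now receives the registry, and the LLM-backed condensers below (LLMAttention, LLMSummarizing, StructuredSummary) obtain their model via llm_registry.get_llm('condenser', llm_config) rather than constructing an LLM directly, so condenser spend flows into the same ConversationStats. A short usage sketch mirroring the CodeActAgent change above:

    # Inside an agent's __init__, as CodeActAgent now does:
    self.condenser = Condenser.from_config(self.config.condenser, llm_registry)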

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
 from openhands.events.action.agent import CondensationAction
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -58,7 +59,9 @@ class AmortizedForgettingCondenser(RollingCondenser):
     @classmethod
     def from_config(
-        cls, config: AmortizedForgettingCondenserConfig
+        cls,
+        config: AmortizedForgettingCondenserConfig,
+        llm_registry: LLMRegistry,
     ) -> AmortizedForgettingCondenser:
         return AmortizedForgettingCondenser(**config.model_dump(exclude={'type'}))

View File

@@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import BrowserOutputCondenserConfig
 from openhands.events.event import Event
 from openhands.events.observation import BrowserOutputObservation
 from openhands.events.observation.agent import AgentCondensationObservation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View
@@ -40,7 +41,7 @@ class BrowserOutputCondenser(Condenser):
     @classmethod
     def from_config(
-        cls, config: BrowserOutputCondenserConfig
+        cls, config: BrowserOutputCondenserConfig, llm_registry: LLMRegistry
     ) -> BrowserOutputCondenser:
         return BrowserOutputCondenser(**config.model_dump(exclude={'type'}))

View File

@@ -9,6 +9,7 @@ from openhands.events.action.agent import (
 from openhands.events.action.message import MessageAction, SystemMessageAction
 from openhands.events.event import EventSource
 from openhands.events.observation import Observation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
@@ -177,7 +178,9 @@ class ConversationWindowCondenser(RollingCondenser):
     @classmethod
     def from_config(
-        cls, _config: ConversationWindowCondenserConfig
+        cls,
+        _config: ConversationWindowCondenserConfig,
+        llm_registry: LLMRegistry,
     ) -> ConversationWindowCondenser:
         return ConversationWindowCondenser()

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
 from openhands.core.config.condenser_config import LLMAttentionCondenserConfig
 from openhands.events.action.agent import CondensationAction
 from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -22,7 +23,12 @@ class ImportantEventSelection(BaseModel):
 class LLMAttentionCondenser(RollingCondenser):
     """Rolling condenser strategy that uses an LLM to select the most important events when condensing the history."""

-    def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
+    def __init__(
+        self,
+        llm: LLM,
+        max_size: int = 100,
+        keep_first: int = 1,
+    ):
         if keep_first >= max_size // 2:
             raise ValueError(
                 f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
@@ -113,15 +119,19 @@ class LLMAttentionCondenser(RollingCondenser):
         return len(view) > self.max_size

     @classmethod
-    def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser:
+    def from_config(
+        cls, config: LLMAttentionCondenserConfig, llm_registry: LLMRegistry
+    ) -> LLMAttentionCondenser:
         # This condenser cannot take advantage of prompt caching. If it happens
         # to be set, we'll pay for the cache writes but never get a chance to
         # save on a read.
         llm_config = config.llm_config.model_copy()
         llm_config.caching_prompt = False
+        llm = llm_registry.get_llm('condenser', llm_config)
         return LLMAttentionCondenser(
-            llm=LLM(config=llm_config),
+            llm=llm,
             max_size=config.max_size,
             keep_first=config.keep_first,
         )

View File

@@ -5,7 +5,8 @@ from openhands.core.message import Message, TextContent
 from openhands.events.action.agent import CondensationAction
 from openhands.events.observation.agent import AgentCondensationObservation
 from openhands.events.serialization.event import truncate_content
-from openhands.llm import LLM
+from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -154,16 +155,17 @@ CURRENT_STATE: Last flip: Heads, Haiku count: 15/20"""
     @classmethod
     def from_config(
-        cls, config: LLMSummarizingCondenserConfig
+        cls, config: LLMSummarizingCondenserConfig, llm_registry: LLMRegistry
     ) -> LLMSummarizingCondenser:
         # This condenser cannot take advantage of prompt caching. If it happens
         # to be set, we'll pay for the cache writes but never get a chance to
         # save on a read.
         llm_config = config.llm_config.model_copy()
         llm_config.caching_prompt = False
+        llm = llm_registry.get_llm('condenser', llm_config)
         return LLMSummarizingCondenser(
-            llm=LLM(config=llm_config),
+            llm=llm,
             max_size=config.max_size,
             keep_first=config.keep_first,
             max_event_length=config.max_event_length,

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from openhands.core.config.condenser_config import NoOpCondenserConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View
@@ -12,7 +13,9 @@ class NoOpCondenser(Condenser):
         return view

     @classmethod
-    def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:
+    def from_config(
+        cls, config: NoOpCondenserConfig, llm_registry: LLMRegistry
+    ) -> NoOpCondenser:
         return NoOpCondenser()

View File

@@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import ObservationMaskingCondenserCo
 from openhands.events.event import Event
 from openhands.events.observation import Observation
 from openhands.events.observation.agent import AgentCondensationObservation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View
@@ -28,7 +29,9 @@ class ObservationMaskingCondenser(Condenser):
     @classmethod
     def from_config(
-        cls, config: ObservationMaskingCondenserConfig
+        cls,
+        config: ObservationMaskingCondenserConfig,
+        llm_registry: LLMRegistry,
     ) -> ObservationMaskingCondenser:
         return ObservationMaskingCondenser(**config.model_dump(exclude={'type'}))

View File

@@ -4,6 +4,7 @@ from contextlib import contextmanager
 from openhands.controller.state.state import State
 from openhands.core.config.condenser_config import CondenserPipelineConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser
 from openhands.memory.view import View
@@ -39,8 +40,10 @@ class CondenserPipeline(Condenser):
         return result

     @classmethod
-    def from_config(cls, config: CondenserPipelineConfig) -> CondenserPipeline:
-        condensers = [Condenser.from_config(c) for c in config.condensers]
+    def from_config(
+        cls, config: CondenserPipelineConfig, llm_registry: LLMRegistry
+    ) -> CondenserPipeline:
+        condensers = [Condenser.from_config(c, llm_registry) for c in config.condensers]
         return CondenserPipeline(*condensers)

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from openhands.core.config.condenser_config import RecentEventsCondenserConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View
@@ -21,7 +22,9 @@ class RecentEventsCondenser(Condenser):
         return View(events=head + tail)

     @classmethod
-    def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser:
+    def from_config(
+        cls, config: RecentEventsCondenserConfig, llm_registry: LLMRegistry
+    ) -> RecentEventsCondenser:
         return RecentEventsCondenser(**config.model_dump(exclude={'type'}))

View File

@@ -13,7 +13,8 @@ from openhands.core.message import Message, TextContent
from openhands.events.action.agent import CondensationAction
from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.serialization.event import truncate_content
-from openhands.llm import LLM
+from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import (
    Condensation,
    RollingCondenser,
@@ -180,15 +181,14 @@ class StructuredSummaryCondenser(RollingCondenser):
        if max_size < 1:
            raise ValueError(f'max_size ({max_size}) cannot be non-positive')
-        if not llm.is_function_calling_active():
-            raise ValueError(
-                'LLM must support function calling to use StructuredSummaryCondenser'
-            )
        self.max_size = max_size
        self.keep_first = keep_first
        self.max_event_length = max_event_length
        self.llm = llm
+        if not self.llm.is_function_calling_active():
+            raise ValueError(
+                'LLM must support function calling to use StructuredSummaryCondenser'
+            )
        super().__init__()
@@ -309,16 +309,17 @@ Capture all relevant information, especially:
    @classmethod
    def from_config(
-        cls, config: StructuredSummaryCondenserConfig
+        cls, config: StructuredSummaryCondenserConfig, llm_registry: LLMRegistry
    ) -> StructuredSummaryCondenser:
        # This condenser cannot take advantage of prompt caching. If it happens
        # to be set, we'll pay for the cache writes but never get a chance to
        # save on a read.
        llm_config = config.llm_config.model_copy()
        llm_config.caching_prompt = False
+        llm = llm_registry.get_llm('condenser', llm_config)
        return StructuredSummaryCondenser(
-            llm=LLM(config=llm_config),
+            llm=llm,
            max_size=config.max_size,
            keep_first=config.keep_first,
            max_event_length=config.max_event_length,
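For condensers that do use an LLM, the instance is now obtained from the registry under a stable service id instead of being constructed ad hoc. A sketch restating the hunk above, with config and llm_registry assumed to exist:

llm_config = config.llm_config.model_copy()
llm_config.caching_prompt = False  # cache writes would never be read back for one-shot summaries
llm = llm_registry.get_llm('condenser', llm_config)
condenser = StructuredSummaryCondenser(
    llm=llm,
    max_size=config.max_size,
    keep_first=config.keep_first,
    max_event_length=config.max_event_length,
)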

View File

@@ -23,7 +23,7 @@ class ServiceContext:
    def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
        self._strategy = strategy
        if llm_config is not None:
-            self.llm = LLM(llm_config)
+            self.llm = LLM(llm_config, service_id='resolver')
    def set_strategy(self, strategy: IssueHandlerInterface) -> None:
        self._strategy = strategy

View File

@@ -28,6 +28,7 @@ from openhands.events.observation import (
)
from openhands.events.stream import EventStreamSubscriber
from openhands.integrations.service_types import ProviderType
+from openhands.llm.llm_registry import LLMRegistry
from openhands.resolver.interfaces.issue import Issue
from openhands.resolver.interfaces.issue_definitions import (
    ServiceContextIssue,
@@ -412,7 +413,8 @@ class IssueResolver:
            shutil.rmtree(self.workspace_base)
        shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base)
-        runtime = create_runtime(self.app_config)
+        llm_registry = LLMRegistry(self.app_config)
+        runtime = create_runtime(self.app_config, llm_registry)
        await runtime.connect()
        def on_event(evt: Event) -> None:

View File

@@ -463,7 +463,7 @@ def update_existing_pull_request(
    # Summarize with LLM if provided
    if llm_config is not None:
-        llm = LLM(llm_config)
+        llm = LLM(llm_config, service_id='resolver')
        with open(
            os.path.join(
                os.path.dirname(__file__),

View File

@@ -54,6 +54,7 @@ from openhands.integrations.provider import (
    ProviderType,
)
from openhands.integrations.service_types import AuthenticationError
+from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent import (
    BaseMicroagent,
    load_microagents_from_dir,
@@ -125,6 +126,7 @@ class Runtime(FileEditRuntimeMixin):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -178,7 +180,9 @@ class Runtime(FileEditRuntimeMixin):
        # Load mixins
        FileEditRuntimeMixin.__init__(
-            self, enable_llm_editor=config.get_agent_config().enable_llm_editor
+            self,
+            enable_llm_editor=config.get_agent_config().enable_llm_editor,
+            llm_registry=llm_registry,
        )
        self.user_id = user_id

View File

@@ -43,6 +43,7 @@ from openhands.events.observation import (
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.request import send_request
@@ -68,6 +69,7 @@ class ActionExecutionClient(Runtime):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -85,6 +87,7 @@ class ActionExecutionClient(Runtime):
        super().__init__(
            config,
            event_stream,
+            llm_registry,
            sid,
            plugins,
            env_vars,

View File

@@ -46,6 +46,7 @@ from openhands.events.observation import (
    Observation,
)
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime_status import RuntimeStatus
@@ -107,6 +108,7 @@ class CLIRuntime(Runtime):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -119,6 +121,7 @@ class CLIRuntime(Runtime):
        super().__init__(
            config,
            event_stream,
+            llm_registry,
            sid,
            plugins,
            env_vars,

View File

@@ -20,6 +20,7 @@ from openhands.core.logger import DEBUG, DEBUG_RUNTIME
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
@@ -90,6 +91,7 @@ class DockerRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -143,6 +145,7 @@ class DockerRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
+            llm_registry,
            sid,
            plugins,
            env_vars,
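All runtime implementations now accept the registry as a constructor argument instead of building LLM instances internally. A wiring sketch that mirrors the test fixture later in this diff (config and event_stream are assumed to exist):

from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime

llm_registry = LLMRegistry(config=OpenHandsConfig())
runtime = DockerRuntime(
    config=config,              # assumed OpenHandsConfig for the sandbox
    event_stream=event_stream,  # assumed existing EventStream
    llm_registry=llm_registry,
    sid='default',
)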

View File

@@ -43,6 +43,7 @@ from openhands.core.logger import DEBUG
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
)
@@ -81,6 +82,7 @@ class KubernetesRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -137,6 +139,7 @@ class KubernetesRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
+            llm_registry,
            sid,
            plugins,
            env_vars,

View File

@@ -26,6 +26,7 @@ from openhands.events.observation import (
)
from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
)
@@ -135,6 +136,7 @@ class LocalRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -186,6 +188,7 @@ class LocalRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
+            llm_registry,
            sid,
            plugins,
            env_vars,
@@ -801,12 +804,6 @@ def _create_warm_server_in_background(
def _get_plugins(config: OpenHandsConfig) -> list[PluginRequirement]:
    from openhands.controller.agent import Agent
-    from openhands.llm.llm import LLM
-    agent_config = config.get_agent_config(config.default_agent)
-    llm = LLM(
-        config=config.get_llm_config_from_agent(config.default_agent),
-    )
-    agent = Agent.get_cls(config.default_agent)(llm, agent_config)
-    plugins = agent.sandbox_plugins
+    plugins = Agent.get_cls(config.default_agent).sandbox_plugins
    return plugins

View File

@@ -19,6 +19,7 @@ from openhands.core.exceptions import (
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
@@ -51,6 +52,7 @@ class RemoteRuntime(ActionExecutionClient):
        self,
        config: OpenHandsConfig,
        event_stream: EventStream,
+        llm_registry: LLMRegistry,
        sid: str = 'default',
        plugins: list[PluginRequirement] | None = None,
        env_vars: dict[str, str] | None = None,
@@ -64,6 +66,7 @@ class RemoteRuntime(ActionExecutionClient):
        super().__init__(
            config,
            event_stream,
+            llm_registry,
            sid,
            plugins,
            env_vars,

View File

@@ -23,7 +23,7 @@ from openhands.events.observation import (
)
from openhands.linter import DefaultLinter
from openhands.llm.llm import LLM
-from openhands.llm.metrics import Metrics
+from openhands.llm.llm_registry import LLMRegistry
from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
USER_MSG = """
@@ -128,7 +128,13 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
    # This restricts the number of lines we can edit to avoid exceeding the token limit.
    MAX_LINES_TO_EDIT = 300
-    def __init__(self, enable_llm_editor: bool, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self,
+        enable_llm_editor: bool,
+        llm_registry: LLMRegistry,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
        super().__init__(*args, **kwargs)
        self.enable_llm_editor = enable_llm_editor
@@ -138,7 +144,6 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
        draft_editor_config = self.config.get_llm_config('draft_editor')
        # manually set the model name for the draft editor LLM to distinguish token costs
-        llm_metrics = Metrics(model_name='draft_editor:' + draft_editor_config.model)
        if draft_editor_config.caching_prompt:
            logger.debug(
                'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
@@ -146,7 +151,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
            )
            draft_editor_config.caching_prompt = False
-        self.draft_editor_llm = LLM(draft_editor_config, metrics=llm_metrics)
+        self.draft_editor_llm = llm_registry.get_llm(
+            'draft_editor_llm', draft_editor_config
+        )
        logger.debug(
            f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
        )

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
import socketio
from openhands.core.config import OpenHandsConfig
+from openhands.core.config.llm_config import LLMConfig
from openhands.events.action import MessageAction
from openhands.server.config.server_config import ServerConfig
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
@@ -136,6 +137,16 @@ class ConversationManager(ABC):
    ) -> list[AgentLoopInfo]:
        """Get the AgentLoopInfo for conversations."""
+    @abstractmethod
+    async def request_llm_completion(
+        self,
+        sid: str,
+        service_id: str,
+        llm_config: LLMConfig,
+        messages: list[dict[str, str]],
+    ) -> str:
+        """Request extraneous llm completions for a conversation"""
    @classmethod
    @abstractmethod
    def get_instance(
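The new hook lets server code run one-off completions through a conversation's registry so their cost is attributed to that conversation. A hedged sketch of a caller (the service id and message content are illustrative only):

answer = await conversation_manager.request_llm_completion(
    sid=conversation_id,
    service_id='my_one_off_service',  # hypothetical service id
    llm_config=llm_config,
    messages=[{'role': 'user', 'content': 'Summarize the last run.'}],
)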

View File

@@ -15,13 +15,13 @@ from docker.models.containers import Container
from openhands.controller.agent import Agent
from openhands.core.config import OpenHandsConfig
+from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import MessageAction
from openhands.events.nested_event_store import NestedEventStore
from openhands.events.stream import EventStream
from openhands.experiments.experiment_manager import ExperimentManagerImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
-from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
from openhands.server.config.server_config import ServerConfig
@@ -42,6 +42,7 @@ from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_dir
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.import_utils import get_impl
+from openhands.utils.utils import create_registry_and_convo_stats
@dataclass
@@ -275,6 +276,16 @@ class DockerNestedConversationManager(ConversationManager):
        # Not supported - clients should connect directly to the nested server!
        raise ValueError('unsupported_operation')
+    async def request_llm_completion(
+        self,
+        sid: str,
+        service_id: str,
+        llm_config: LLMConfig,
+        messages: list[dict[str, str]],
+    ) -> str:
+        # Not supported - clients should connect directly to the nested server!
+        raise ValueError('unsupported_operation')
    async def send_event_to_conversation(self, sid, data):
        async with httpx.AsyncClient(
            headers={
@@ -471,27 +482,27 @@ class DockerNestedConversationManager(ConversationManager):
        # This session is created here only because it is the easiest way to get a runtime, which
        # is the easiest way to create the needed docker container
+        # Run experiment manager variant test before creating session
        config: OpenHandsConfig = ExperimentManagerImpl.run_config_variant_test(
            user_id, sid, self.config
        )
+        llm_registry, convo_stats, config = create_registry_and_convo_stats(
+            config, sid, user_id, settings
+        )
        session = Session(
            sid=sid,
+            llm_registry=llm_registry,
+            convo_stats=convo_stats,
            file_store=self.file_store,
            config=config,
            sio=self.sio,
            user_id=user_id,
        )
+        llm_registry.retry_listner = session._notify_on_llm_retry
        agent_cls = settings.agent or config.default_agent
-        agent_name = agent_cls if agent_cls is not None else 'agent'
-        llm = LLM(
-            config=config.get_llm_config_from_agent(agent_name),
-            retry_listener=session._notify_on_llm_retry,
-        )
-        llm = session._create_llm(agent_cls)
        agent_config = config.get_agent_config(agent_cls)
-        agent = Agent.get_cls(agent_cls)(llm, agent_config)
+        agent = Agent.get_cls(agent_cls)(agent_config, llm_registry)
        config = config.model_copy(deep=True)
        env_vars = config.sandbox.runtime_startup_env_vars
@@ -543,6 +554,7 @@ class DockerNestedConversationManager(ConversationManager):
            headless_mode=False,
            attach_to_existing=False,
            main_module='openhands.server',
+            llm_registry=llm_registry,
        )
        # Hack - disable setting initial env.

View File

@@ -6,12 +6,14 @@ from typing import Callable, Iterable
import socketio
+from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.stream import EventStreamSubscriber, session_exists
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime import get_runtime_cls
from openhands.server.config.server_config import ServerConfig
from openhands.server.constants import ROOM_KEY
@@ -37,6 +39,7 @@ from openhands.utils.conversation_summary import (
)
from openhands.utils.import_utils import get_impl
from openhands.utils.shutdown_listener import should_continue
+from openhands.utils.utils import create_registry_and_convo_stats
from .conversation_manager import ConversationManager
@@ -332,12 +335,15 @@ class StandaloneConversationManager(ConversationManager):
            )
            await self.close_session(oldest_conversation_id)
-        config = self.config.model_copy(deep=True)
+        llm_registry, convo_stats, config = create_registry_and_convo_stats(
+            self.config, sid, user_id, settings
+        )
        session = Session(
            sid=sid,
            file_store=self.file_store,
            config=config,
+            llm_registry=llm_registry,
+            convo_stats=convo_stats,
            sio=self.sio,
            user_id=user_id,
        )
@@ -349,7 +355,9 @@ class StandaloneConversationManager(ConversationManager):
        try:
            session.agent_session.event_stream.subscribe(
                EventStreamSubscriber.SERVER,
-                self._create_conversation_update_callback(user_id, sid, settings),
+                self._create_conversation_update_callback(
+                    user_id, sid, settings, session.llm_registry
+                ),
                UPDATED_AT_CALLBACK_ID,
            )
        except ValueError:
@@ -369,6 +377,21 @@ class StandaloneConversationManager(ConversationManager):
            raise RuntimeError(f'no_conversation:{sid}')
        await session.dispatch(data)
+    async def request_llm_completion(
+        self,
+        sid: str,
+        service_id: str,
+        llm_config: LLMConfig,
+        messages: list[dict[str, str]],
+    ):
+        session = self._local_agent_loops_by_sid.get(sid)
+        if not session:
+            raise RuntimeError(f'no_conversation:{sid}')
+        llm_registry = session.llm_registry
+        return llm_registry.request_extraneous_completion(
+            service_id, llm_config, messages
+        )
    async def disconnect_from_session(self, connection_id: str):
        sid = self._local_connection_id_to_session_id.pop(connection_id, None)
        logger.info(
@@ -450,6 +473,7 @@ class StandaloneConversationManager(ConversationManager):
        user_id: str | None,
        conversation_id: str,
        settings: Settings,
+        llm_registry: LLMRegistry,
    ) -> Callable:
        def callback(event, *args, **kwargs):
            call_async_from_sync(
@@ -458,6 +482,7 @@ class StandaloneConversationManager(ConversationManager):
                user_id,
                conversation_id,
                settings,
+                llm_registry,
                event,
            )
@@ -468,6 +493,7 @@ class StandaloneConversationManager(ConversationManager):
        user_id: str,
        conversation_id: str,
        settings: Settings,
+        llm_registry: LLMRegistry,
        event=None,
    ):
        conversation_store = await self._get_conversation_store(user_id)
@@ -495,7 +521,7 @@ class StandaloneConversationManager(ConversationManager):
                conversation.title == default_title
            ): # attempt to autogenerate if default title is in use
                title = await auto_generate_title(
-                    conversation_id, user_id, self.file_store, settings
+                    conversation_id, user_id, self.file_store, settings, llm_registry
                )
                if title and not title.isspace():
                    conversation.title = title

View File

@@ -33,7 +33,6 @@ from openhands.integrations.service_types import (
    ProviderType,
    SuggestedTask,
)
-from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
@@ -47,6 +46,7 @@ from openhands.server.services.conversation_service import (
    setup_init_conversation_settings,
)
from openhands.server.shared import (
+    ConversationManagerImpl,
    ConversationStoreImpl,
    config,
    conversation_manager,
@@ -364,7 +364,7 @@ async def get_prompt(
    )
    prompt_template = generate_prompt_template(stringified_events)
-    prompt = generate_prompt(llm_config, prompt_template)
+    prompt = generate_prompt(llm_config, prompt_template, conversation_id)
    return JSONResponse(
        {
@@ -380,8 +380,9 @@ def generate_prompt_template(events: str) -> str:
    return template.render(events=events)
-def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
-    llm = LLM(llm_config)
+def generate_prompt(
+    llm_config: LLMConfig, prompt_template: str, conversation_id: str
+) -> str:
    messages = [
        {
            'role': 'system',
@@ -393,8 +394,9 @@ def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
        },
    ]
-    response = llm.completion(messages=messages)
-    raw_prompt = response['choices'][0]['message']['content'].strip()
+    raw_prompt = ConversationManagerImpl.request_llm_completion(
+        'remember_prompt', conversation_id, llm_config, messages
+    )
    prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
    if prompt:
if prompt: if prompt:

View File

@@ -31,20 +31,60 @@ from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.utils.conversation_summary import get_default_conversation_title
-async def create_new_conversation(
+async def initialize_conversation(
+    user_id: str | None,
+    conversation_id: str | None,
+    selected_repository: str | None,
+    selected_branch: str | None,
+    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
+    git_provider: ProviderType | None = None,
+) -> ConversationMetadata | None:
+    if conversation_id is None:
+        conversation_id = uuid.uuid4().hex
+    conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
+    if not await conversation_store.exists(conversation_id):
+        logger.info(
+            f'New conversation ID: {conversation_id}',
+            extra={'user_id': user_id, 'session_id': conversation_id},
+        )
+        conversation_title = get_default_conversation_title(conversation_id)
+        logger.info(f'Saving metadata for conversation {conversation_id}')
+        convo_metadata = ConversationMetadata(
+            trigger=conversation_trigger,
+            conversation_id=conversation_id,
+            title=conversation_title,
+            user_id=user_id,
+            selected_repository=selected_repository,
+            selected_branch=selected_branch,
+            git_provider=git_provider,
+        )
+        await conversation_store.save_metadata(convo_metadata)
+        return convo_metadata
+    try:
+        convo_metadata = await conversation_store.get_metadata(conversation_id)
+        return convo_metadata
+    except Exception:
+        pass
+    return None
+async def start_conversation(
    user_id: str | None,
    git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
    custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
-    selected_repository: str | None,
-    selected_branch: str | None,
    initial_user_msg: str | None,
    image_urls: list[str] | None,
    replay_json: str | None,
-    conversation_instructions: str | None = None,
-    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
-    attach_conversation_id: bool = False,
-    git_provider: ProviderType | None = None,
-    conversation_id: str | None = None,
+    conversation_id: str,
+    convo_metadata: ConversationMetadata,
+    conversation_instructions: str | None,
    mcp_config: MCPConfig | None = None,
) -> AgentLoopInfo:
    logger.info(
@@ -52,7 +92,7 @@ async def create_new_conversation(
        extra={
            'signal': 'create_conversation',
            'user_id': user_id,
-            'trigger': conversation_trigger.value,
+            'trigger': convo_metadata.trigger,
        },
    )
    logger.info('Loading settings')
@@ -79,53 +119,25 @@ async def create_new_conversation(
        raise MissingSettingsError('Settings not found')
    session_init_args['git_provider_tokens'] = git_provider_tokens
-    session_init_args['selected_repository'] = selected_repository
+    session_init_args['selected_repository'] = convo_metadata.selected_repository
    session_init_args['custom_secrets'] = custom_secrets
-    session_init_args['selected_branch'] = selected_branch
-    session_init_args['git_provider'] = git_provider
+    session_init_args['selected_branch'] = convo_metadata.selected_branch
+    session_init_args['git_provider'] = convo_metadata.git_provider
    session_init_args['conversation_instructions'] = conversation_instructions
    if mcp_config:
        session_init_args['mcp_config'] = mcp_config
    conversation_init_data = ConversationInitData(**session_init_args)
-    logger.info('Loading conversation store')
-    conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
-    logger.info('ServerConversation store loaded')
-    # For nested runtimes, we allow a single conversation id, passed in on container creation
-    if conversation_id is None:
-        conversation_id = uuid.uuid4().hex
-    if not await conversation_store.exists(conversation_id):
-        logger.info(
-            f'New conversation ID: {conversation_id}',
-            extra={'user_id': user_id, 'session_id': conversation_id},
-        )
    conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
        user_id, conversation_id, conversation_init_data
    )
-    conversation_title = get_default_conversation_title(conversation_id)
-    logger.info(f'Saving metadata for conversation {conversation_id}')
-    await conversation_store.save_metadata(
-        ConversationMetadata(
-            trigger=conversation_trigger,
-            conversation_id=conversation_id,
-            title=conversation_title,
-            user_id=user_id,
-            selected_repository=selected_repository,
-            selected_branch=selected_branch,
-            git_provider=git_provider,
-            llm_model=conversation_init_data.llm_model,
-        )
-    )
    logger.info(
        f'Starting agent loop for conversation {conversation_id}',
        extra={'user_id': user_id, 'session_id': conversation_id},
    )
    initial_message_action = None
    if initial_user_msg or image_urls:
        initial_message_action = MessageAction(
@@ -133,9 +145,6 @@ async def create_new_conversation(
            image_urls=image_urls or [],
        )
-    if attach_conversation_id:
-        logger.warning('Attaching conversation ID is deprecated, skipping process')
    agent_loop_info = await conversation_manager.maybe_start_agent_loop(
        conversation_id,
        conversation_init_data,
@@ -147,6 +156,47 @@ async def create_new_conversation(
    return agent_loop_info
+async def create_new_conversation(
+    user_id: str | None,
+    git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
+    custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
+    selected_repository: str | None,
+    selected_branch: str | None,
+    initial_user_msg: str | None,
+    image_urls: list[str] | None,
+    replay_json: str | None,
+    conversation_instructions: str | None = None,
+    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
+    git_provider: ProviderType | None = None,
+    conversation_id: str | None = None,
+    mcp_config: MCPConfig | None = None,
+) -> AgentLoopInfo:
+    conversation_metadata = await initialize_conversation(
+        user_id,
+        conversation_id,
+        selected_repository,
+        selected_branch,
+        conversation_trigger,
+        git_provider,
+    )
+    if not conversation_metadata:
+        raise Exception('Failed to initialize conversation')
+    return await start_conversation(
+        user_id,
+        git_provider_tokens,
+        custom_secrets,
+        initial_user_msg,
+        image_urls,
+        replay_json,
+        conversation_metadata.conversation_id,
+        conversation_metadata,
+        conversation_instructions,
+        mcp_config,
+    )
def create_provider_tokens_object(
    providers_set: list[ProviderType],
) -> PROVIDER_TOKEN_TYPE:

View File

@@ -0,0 +1,77 @@
import base64
import pickle
from threading import Lock
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm_registry import RegistryEvent
from openhands.llm.metrics import Metrics
from openhands.storage.files import FileStore
from openhands.storage.locations import get_conversation_stats_filename
class ConversationStats:
def __init__(
self,
file_store: FileStore | None,
conversation_id: str,
user_id: str | None,
):
self.metrics_path = get_conversation_stats_filename(conversation_id, user_id)
self.file_store = file_store
self.conversation_id = conversation_id
self.user_id = user_id
self._save_lock = Lock()
self.service_to_metrics: dict[str, Metrics] = {}
self.restored_metrics: dict[str, Metrics] = {}
# Always attempt to restore registry if it exists
self.maybe_restore_metrics()
def save_metrics(self):
if not self.file_store:
return
with self._save_lock:
pickled = pickle.dumps(self.service_to_metrics)
serialized_metrics = base64.b64encode(pickled).decode('utf-8')
self.file_store.write(self.metrics_path, serialized_metrics)
def maybe_restore_metrics(self):
if not self.file_store or not self.conversation_id:
return
try:
encoded = self.file_store.read(self.metrics_path)
pickled = base64.b64decode(encoded)
self.restored_metrics = pickle.loads(pickled)
logger.info(f'restored metrics: {self.conversation_id}')
except FileNotFoundError:
pass
def get_combined_metrics(self) -> Metrics:
total_metrics = Metrics()
for metrics in self.service_to_metrics.values():
total_metrics.merge(metrics)
logger.info(f'metrics by all services: {self.service_to_metrics}')
logger.info(f'combined metrics\n\n{total_metrics}')
return total_metrics
def get_metrics_for_service(self, service_id: str) -> Metrics:
if service_id not in self.service_to_metrics:
raise Exception(f'LLM service does not exist {service_id}')
return self.service_to_metrics[service_id]
def register_llm(self, event: RegistryEvent):
# Listen for llm creations and track their metrics
llm = event.llm
service_id = event.service_id
if service_id in self.restored_metrics:
llm.metrics = self.restored_metrics[service_id].copy()
del self.restored_metrics[service_id]
self.service_to_metrics[service_id] = llm.metrics
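ConversationStats subscribes to the registry and keeps one Metrics object per service id, restoring previously persisted metrics when an LLM for a known service is created again. A minimal wiring sketch (mirrors create_registry_and_convo_stats later in this diff; app_config, file_store, conversation_id and user_id are assumed):

llm_registry = LLMRegistry(app_config)
convo_stats = ConversationStats(file_store, conversation_id, user_id)
llm_registry.subscribe(convo_stats.register_llm)

agent_llm = llm_registry.get_llm('agent', app_config.get_llm_config())  # 'agent' is an illustrative service id
total = convo_stats.get_combined_metrics()  # merged metrics across all tracked services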

View File

@@ -21,6 +21,7 @@ from openhands.integrations.provider import (
    PROVIDER_TOKEN_TYPE,
    ProviderHandler,
)
+from openhands.llm.llm_registry import LLMRegistry
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
@@ -29,6 +30,7 @@ from openhands.runtime.base import Runtime
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.security import SecurityAnalyzer, options
+from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.storage.files import FileStore
from openhands.utils.async_utils import EXECUTOR, call_sync_from_async
@@ -48,6 +50,7 @@ class AgentSession:
    sid: str
    user_id: str | None
    event_stream: EventStream
+    llm_registry: LLMRegistry
    file_store: FileStore
    controller: AgentController | None = None
    runtime: Runtime | None = None
@@ -63,6 +66,8 @@ class AgentSession:
        self,
        sid: str,
        file_store: FileStore,
+        llm_registry: LLMRegistry,
+        convo_stats: ConversationStats,
        status_callback: Callable | None = None,
        user_id: str | None = None,
    ) -> None:
@@ -80,6 +85,8 @@ class AgentSession:
        self.logger = OpenHandsLoggerAdapter(
            extra={'session_id': sid, 'user_id': user_id}
        )
+        self.llm_registry = llm_registry
+        self.convo_stats = convo_stats
    async def start(
        self,
@@ -340,6 +347,7 @@ class AgentSession:
        self.runtime = runtime_cls(
            config=config,
            event_stream=self.event_stream,
+            llm_registry=self.llm_registry,
            sid=self.sid,
            plugins=agent.sandbox_plugins,
            status_callback=self._status_callback,
@@ -360,6 +368,7 @@ class AgentSession:
            self.runtime = runtime_cls(
                config=config,
                event_stream=self.event_stream,
+                llm_registry=self.llm_registry,
                sid=self.sid,
                plugins=agent.sandbox_plugins,
                status_callback=self._status_callback,
@@ -441,6 +450,7 @@ class AgentSession:
            user_id=self.user_id,
            file_store=self.file_store,
            event_stream=self.event_stream,
+            convo_stats=self.convo_stats,
            agent=agent,
            iteration_delta=int(max_iterations),
            budget_per_task_delta=max_budget_per_task,
@@ -490,6 +500,15 @@ class AgentSession:
        )
        return memory
+    def get_state(self) -> AgentState | None:
+        controller = self.controller
+        if controller:
+            return controller.state.agent_state
+        if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
+            # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
+            return AgentState.ERROR
+        return None
    def _maybe_restore_state(self) -> State | None:
        """Helper method to handle state restore logic."""
        restored_state = None
@@ -510,14 +529,5 @@ class AgentSession:
            self.logger.debug('No events found, no state to restore')
        return restored_state
-    def get_state(self) -> AgentState | None:
-        controller = self.controller
-        if controller:
-            return controller.state.agent_state
-        if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
-            # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
-            return AgentState.ERROR
-        return None
    def is_closed(self) -> bool:
        return self._closed

View File

@@ -2,6 +2,7 @@ import asyncio
from openhands.core.config import OpenHandsConfig
from openhands.events.stream import EventStream
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
@@ -45,6 +46,7 @@ class ServerConversation:
        else:
            runtime_cls = get_runtime_cls(self.config.runtime)
            runtime = runtime_cls(
+                llm_registry=LLMRegistry(self.config),
                config=config,
                event_stream=self.event_stream,
                sid=self.sid,

View File

@@ -1,6 +1,5 @@
import asyncio
import time
-from copy import deepcopy
from logging import LoggerAdapter
import socketio
@@ -28,9 +27,10 @@ from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.constants import ROOM_KEY
+from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.storage.data_models.settings import Settings
@@ -45,6 +45,7 @@ class Session:
    agent_session: AgentSession
    loop: asyncio.AbstractEventLoop
    config: OpenHandsConfig
+    llm_registry: LLMRegistry
    file_store: FileStore
    user_id: str | None
    logger: LoggerAdapter
@@ -53,6 +54,8 @@ class Session:
        self,
        sid: str,
        config: OpenHandsConfig,
+        llm_registry: LLMRegistry,
+        convo_stats: ConversationStats,
        file_store: FileStore,
        sio: socketio.AsyncServer | None,
        user_id: str | None = None,
@@ -62,17 +65,21 @@ class Session:
        self.last_active_ts = int(time.time())
        self.file_store = file_store
        self.logger = OpenHandsLoggerAdapter(extra={'session_id': sid})
+        self.llm_registry = llm_registry
+        self.convo_stats = convo_stats
        self.agent_session = AgentSession(
            sid,
            file_store,
+            llm_registry=self.llm_registry,
+            convo_stats=convo_stats,
            status_callback=self.queue_status_message,
            user_id=user_id,
        )
        self.agent_session.event_stream.subscribe(
            EventStreamSubscriber.SERVER, self.on_event, self.sid
        )
-        # Copying this means that when we update variables they are not applied to the shared global configuration!
-        self.config = deepcopy(config)
+        self.config = config
        # Lazy import to avoid circular dependency
        from openhands.experiments.experiment_manager import ExperimentManagerImpl
@@ -140,13 +147,6 @@ class Session:
            else self.config.max_budget_per_task
        )
-        # This is a shallow copy of the default LLM config, so changes here will
-        # persist if we retrieve the default LLM config again when constructing
-        # the agent
-        default_llm_config = self.config.get_llm_config()
-        default_llm_config.model = settings.llm_model or ''
-        default_llm_config.api_key = settings.llm_api_key
-        default_llm_config.base_url = settings.llm_base_url
        self.config.search_api_key = settings.search_api_key
        if settings.sandbox_api_key:
            self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value()
@@ -181,10 +181,9 @@ class Session:
        )
        # TODO: override other LLM config & agent config groups (#2075)
-        llm = self._create_llm(agent_cls)
        agent_config = self.config.get_agent_config(agent_cls)
+        agent_name = agent_cls if agent_cls is not None else 'agent'
+        llm_config = self.config.get_llm_config_from_agent(agent_name)
        if settings.enable_default_condenser:
            # Default condenser chains three condensers together:
            # 1. a conversation window condenser that handles explicit
@@ -200,7 +199,7 @@ class Session:
                    ConversationWindowCondenserConfig(),
                    BrowserOutputCondenserConfig(attention_window=2),
                    LLMSummarizingCondenserConfig(
-                        llm_config=llm.config, keep_first=4, max_size=120
+                        llm_config=llm_config, keep_first=4, max_size=120
                    ),
                ]
            )
@@ -208,12 +207,14 @@ class Session:
            self.logger.info(
                f'Enabling pipeline condenser with:'
                f' browser_output_masking(attention_window=2), '
-                f' llm(model="{llm.config.model}", '
-                f' base_url="{llm.config.base_url}", '
+                f' llm(model="{llm_config.model}", '
+                f' base_url="{llm_config.base_url}", '
                f' keep_first=4, max_size=80)'
            )
            agent_config.condenser = default_condenser_config
-        agent = Agent.get_cls(agent_cls)(llm, agent_config)
+        agent = Agent.get_cls(agent_cls)(agent_config, self.llm_registry)
+        self.llm_registry.retry_listner = self._notify_on_llm_retry
        git_provider_tokens = None
        selected_repository = None
@@ -269,14 +270,6 @@ class Session:
            )
            return
-    def _create_llm(self, agent_cls: str | None) -> LLM:
-        """Initialize LLM, extracted for testing."""
-        agent_name = agent_cls if agent_cls is not None else 'agent'
-        return LLM(
-            config=self.config.get_llm_config_from_agent(agent_name),
-            retry_listener=self._notify_on_llm_retry,
-        )
    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
        self.queue_status_message(
            'info', RuntimeStatus.LLM_RETRY, f'Retrying LLM request, {retries} / {max}'

View File

@@ -30,5 +30,13 @@ def get_conversation_agent_state_filename(sid: str, user_id: str | None = None)
    return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'
+def get_conversation_llm_registry_filename(sid: str, user_id: str | None = None) -> str:
+    return f'{get_conversation_dir(sid, user_id)}llm_registry.json'
+def get_conversation_stats_filename(sid: str, user_id: str | None = None) -> str:
+    return f'{get_conversation_dir(sid, user_id)}convo_stats.json'
def get_experiment_config_filename(sid: str, user_id: str | None = None) -> str:
    return f'{get_conversation_dir(sid, user_id)}exp_config.json'

View File

@@ -7,13 +7,16 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.event_store import EventStore
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
async def generate_conversation_title(
-    message: str, llm_config: LLMConfig, max_length: int = 50
+    message: str,
+    llm_config: LLMConfig,
+    llm_registry: LLMRegistry,
+    max_length: int = 50,
) -> Optional[str]:
    """Generate a concise title for a conversation based on the first user message.
@@ -35,8 +38,6 @@ async def generate_conversation_title(
        truncated_message = message
    try:
-        llm = LLM(llm_config)
        # Create a simple prompt for the LLM to generate a title
        messages = [
            {
@@ -49,8 +50,9 @@ async def generate_conversation_title(
            },
        ]
-        response = llm.completion(messages=messages)
-        title = response.choices[0].message.content.strip()
+        title = llm_registry.request_extraneous_completion(
+            'convo_title_creator', llm_config, messages
+        )
        # Ensure the title isn't too long
        if len(title) > max_length:
@@ -75,7 +77,11 @@ def get_default_conversation_title(conversation_id: str) -> str:
async def auto_generate_title(
-    conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
+    conversation_id: str,
+    user_id: str | None,
+    file_store: FileStore,
+    settings: Settings,
+    llm_registry: LLMRegistry,
) -> str:
    """Auto-generate a title for a conversation based on the first user message.
    Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
@@ -116,7 +122,7 @@ async def auto_generate_title(
        # Try to generate title using LLM
        llm_title = await generate_conversation_title(
-            first_user_message, llm_config
+            first_user_message, llm_config, llm_registry
        )
        if llm_title:
            logger.info(f'Generated title using LLM: {llm_title}')
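Title generation now goes through the registry as an extraneous completion under the 'convo_title_creator' service id, so its tokens are counted with the rest of the conversation. Sketch (the system prompt is illustrative; llm_config and llm_registry are assumed to exist):

messages = [
    {'role': 'system', 'content': 'Generate a short title for this conversation.'},  # illustrative prompt
    {'role': 'user', 'content': first_user_message},
]
title = llm_registry.request_extraneous_completion(
    'convo_title_creator', llm_config, messages
)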

37
openhands/utils/utils.py Normal file
View File

@@ -0,0 +1,37 @@
from copy import deepcopy
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage import get_file_store
from openhands.storage.data_models.settings import Settings
def setup_llm_config(config: OpenHandsConfig, settings: Settings) -> OpenHandsConfig:
# Deep-copy so that the updates below do not mutate the shared global configuration.
config = deepcopy(config)
llm_config = config.get_llm_config()
llm_config.model = settings.llm_model or ''
llm_config.api_key = settings.llm_api_key
llm_config.base_url = settings.llm_base_url
config.set_llm_config(llm_config)
return config
def create_registry_and_convo_stats(
config: OpenHandsConfig,
sid: str,
user_id: str | None,
user_settings: Settings | None = None,
) -> tuple[LLMRegistry, ConversationStats, OpenHandsConfig]:
user_config = config
if user_settings:
user_config = setup_llm_config(config, user_settings)
agent_cls = user_settings.agent if user_settings else None
llm_registry = LLMRegistry(user_config, agent_cls)
file_store = get_file_store(user_config.file_store, user_config.file_store_path)
convo_stats = ConversationStats(file_store, sid, user_id)
llm_registry.subscribe(convo_stats.register_llm)
return llm_registry, convo_stats, user_config
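A usage sketch for the new helper; the Settings values are illustrative (in the server they come from stored user settings):

from openhands.core.config import OpenHandsConfig
from openhands.storage.data_models.settings import Settings
from openhands.utils.utils import create_registry_and_convo_stats

config = OpenHandsConfig()
settings = Settings(llm_model='gpt-4o', llm_api_key='test_key')

llm_registry, convo_stats, user_config = create_registry_and_convo_stats(
    config=config,
    sid='conversation-123',
    user_id='user-1',
    user_settings=settings,
)
# Every LLM later created through llm_registry is reported to convo_stats,
# because the helper already subscribed convo_stats.register_llm.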

View File

@@ -1,3 +1,4 @@
[pytest] [pytest]
addopts = -p no:warnings addopts = -p no:warnings
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function asyncio_default_fixture_loop_scope = function
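With asyncio_mode = auto, pytest-asyncio collects coroutine tests without an explicit marker, so new async registry/stats tests can be written as plain async functions, e.g. (illustrative test only):

async def test_example_runs_without_marker():
    # No @pytest.mark.asyncio decorator needed once asyncio_mode = auto is set.
    assert True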

0
tests/__init__.py Normal file
View File

View File

@@ -10,6 +10,7 @@ from pytest import TempPathFactory
from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config
from openhands.core.logger import openhands_logger as logger from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime from openhands.runtime.base import Runtime
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
@@ -268,9 +269,13 @@ def _load_runtime(
) )
event_stream = EventStream(sid, file_store) event_stream = EventStream(sid, file_store)
# Create an LLMRegistry instance for the runtime # Create an LLMRegistry instance for the runtime
llm_registry = LLMRegistry(config=OpenHandsConfig())
runtime = runtime_cls( runtime = runtime_cls(
config=config, config=config,
event_stream=event_stream, event_stream=event_stream,
llm_registry=llm_registry,
sid=sid, sid=sid,
plugins=plugins, plugins=plugins,
) )

0
tests/unit/__init__.py Normal file
View File

View File

@@ -20,7 +20,7 @@ def test_llm():
def _get_llm(type_: type[LLM]): def _get_llm(type_: type[LLM]):
with _patch_http(): with _patch_http():
return type_(config=config.get_llm_config()) return type_(config=config.get_llm_config(), service_id='test_service')
@pytest.fixture @pytest.fixture

View File

@@ -38,7 +38,7 @@ def default_config():
def test_llm_init_with_default_config(default_config): def test_llm_init_with_default_config(default_config):
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
assert llm.config.model == 'gpt-4o' assert llm.config.model == 'gpt-4o'
assert llm.config.api_key.get_secret_value() == 'test_key' assert llm.config.api_key.get_secret_value() == 'test_key'
assert isinstance(llm.metrics, Metrics) assert isinstance(llm.metrics, Metrics)
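The recurring edit across this test module is that LLM construction now takes a service_id, so the registry and metrics layers can attribute usage per service. A minimal sketch in the same pattern as the updated tests:

from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm import LLM

config = LLMConfig(model='gpt-4o', api_key='test_key')
llm = LLM(config, service_id='test-service')
assert llm.config.model == 'gpt-4o'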
@@ -129,7 +129,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
'max_input_tokens': 8000, 'max_input_tokens': 8000,
'max_output_tokens': 2000, 'max_output_tokens': 2000,
} }
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
llm.init_model_info() llm.init_model_info()
assert llm.config.max_input_tokens == 8000 assert llm.config.max_input_tokens == 8000
assert llm.config.max_output_tokens == 2000 assert llm.config.max_output_tokens == 2000
@@ -138,7 +138,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
@patch('openhands.llm.llm.litellm.get_model_info') @patch('openhands.llm.llm.litellm.get_model_info')
def test_llm_init_without_model_info(mock_get_model_info, default_config): def test_llm_init_without_model_info(mock_get_model_info, default_config):
mock_get_model_info.side_effect = Exception('Model info not available') mock_get_model_info.side_effect = Exception('Model info not available')
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
llm.init_model_info() llm.init_model_info()
assert llm.config.max_input_tokens is None assert llm.config.max_input_tokens is None
assert llm.config.max_output_tokens is None assert llm.config.max_output_tokens is None
@@ -154,7 +154,7 @@ def test_llm_init_with_custom_config():
top_p=0.9, top_p=0.9,
top_k=None, top_k=None,
) )
llm = LLM(custom_config) llm = LLM(custom_config, service_id='test-service')
assert llm.config.model == 'custom-model' assert llm.config.model == 'custom-model'
assert llm.config.api_key.get_secret_value() == 'custom_key' assert llm.config.api_key.get_secret_value() == 'custom_key'
assert llm.config.max_input_tokens == 5000 assert llm.config.max_input_tokens == 5000
@@ -168,7 +168,7 @@ def test_llm_init_with_custom_config():
def test_llm_top_k_in_completion_when_set(mock_litellm_completion): def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
# Create a config with top_k set # Create a config with top_k set
config_with_top_k = LLMConfig(top_k=50) config_with_top_k = LLMConfig(top_k=50)
llm = LLM(config_with_top_k) llm = LLM(config_with_top_k, service_id='test-service')
# Define a side effect function to check top_k # Define a side effect function to check top_k
def side_effect(*args, **kwargs): def side_effect(*args, **kwargs):
@@ -186,7 +186,7 @@ def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion): def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
# Create a config with top_k set to None # Create a config with top_k set to None
config_without_top_k = LLMConfig(top_k=None) config_without_top_k = LLMConfig(top_k=None)
llm = LLM(config_without_top_k) llm = LLM(config_without_top_k, service_id='test-service')
# Define a side effect function to check top_k # Define a side effect function to check top_k
def side_effect(*args, **kwargs): def side_effect(*args, **kwargs):
@@ -202,7 +202,7 @@ def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
def test_llm_init_with_metrics(): def test_llm_init_with_metrics():
config = LLMConfig(model='gpt-4o', api_key='test_key') config = LLMConfig(model='gpt-4o', api_key='test_key')
metrics = Metrics() metrics = Metrics()
llm = LLM(config, metrics=metrics) llm = LLM(config, metrics=metrics, service_id='test-service')
assert llm.metrics is metrics assert llm.metrics is metrics
assert ( assert (
llm.metrics.model_name == 'default' llm.metrics.model_name == 'default'
@@ -224,7 +224,7 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion):
# Create LLM instance and make a completion call # Create LLM instance and make a completion call
config = LLMConfig(model='gpt-4o', api_key='test_key') config = LLMConfig(model='gpt-4o', api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}]) response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
# Verify the response latency was tracked correctly # Verify the response latency was tracked correctly
@@ -257,7 +257,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
'max_input_tokens': 7000, 'max_input_tokens': 7000,
'max_output_tokens': 1500, 'max_output_tokens': 1500,
} }
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
llm.init_model_info() llm.init_model_info()
assert llm.config.max_input_tokens == 7000 assert llm.config.max_input_tokens == 7000
assert llm.config.max_output_tokens == 1500 assert llm.config.max_output_tokens == 1500
@@ -280,7 +280,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
default_config.model = ( default_config.model = (
'custom-model' # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS 'custom-model' # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS
) )
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
llm.completion( llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
tools=[ tools=[
@@ -292,7 +292,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
# Test with Grok-4 model that doesn't support stop parameter # Test with Grok-4 model that doesn't support stop parameter
default_config.model = 'xai/grok-4-0709' default_config.model = 'xai/grok-4-0709'
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
llm.completion( llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
tools=[ tools=[
@@ -314,7 +314,7 @@ def test_completion_with_mocked_logger(
'choices': [{'message': {'content': 'Test response'}}] 'choices': [{'message': {'content': 'Test response'}}]
} }
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -345,7 +345,7 @@ def test_completion_retries(
{'choices': [{'message': {'content': 'Retry successful'}}]}, {'choices': [{'message': {'content': 'Retry successful'}}]},
] ]
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -365,7 +365,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
{'choices': [{'message': {'content': 'Retry successful'}}]}, {'choices': [{'message': {'content': 'Retry successful'}}]},
] ]
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -387,7 +387,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
def test_completion_operation_cancelled(mock_litellm_completion, default_config): def test_completion_operation_cancelled(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled') mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(OperationCancelled): with pytest.raises(OperationCancelled):
llm.completion( llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
@@ -404,7 +404,7 @@ def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = side_effect mock_litellm_completion.side_effect = side_effect
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(OperationCancelled): with pytest.raises(OperationCancelled):
try: try:
llm.completion( llm.completion(
@@ -428,7 +428,7 @@ def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_
mock_litellm_completion.side_effect = side_effect mock_litellm_completion.side_effect = side_effect
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
result = llm.completion( result = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -469,7 +469,7 @@ def test_completion_retry_with_llm_no_response_error_zero_temp(
mock_litellm_completion.side_effect = side_effect mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call # Create LLM instance and make a completion call
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -509,7 +509,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp(
'LLM did not return a response' 'LLM did not return a response'
) )
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
with pytest.raises(LLMNoResponseError): with pytest.raises(LLMNoResponseError):
llm.completion( llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
@@ -575,7 +575,7 @@ def test_gemini_25_pro_function_calling(mock_httpx_get, mock_get_model_info):
for model_name, expected_support in test_cases: for model_name, expected_support in test_cases:
config = LLMConfig(model=model_name, api_key='test_key') config = LLMConfig(model=model_name, api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
assert llm.is_function_calling_active() == expected_support, ( assert llm.is_function_calling_active() == expected_support, (
f'Expected function calling support to be {expected_support} for model {model_name}' f'Expected function calling support to be {expected_support} for model {model_name}'
@@ -617,7 +617,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp_successful_ret
mock_litellm_completion.side_effect = side_effect mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call with non-zero temperature # Create LLM instance and make a completion call with non-zero temperature
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -677,7 +677,7 @@ def test_completion_retry_with_llm_no_response_error_successful_retry(
mock_litellm_completion.side_effect = side_effect mock_litellm_completion.side_effect = side_effect
# Create LLM instance and make a completion call with explicit temperature=0 # Create LLM instance and make a completion call with explicit temperature=0
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -709,7 +709,7 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
} }
mock_litellm_completion.return_value = mock_response mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config) test_llm = LLM(config=default_config, service_id='test-service')
response = test_llm.completion( response = test_llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -743,7 +743,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
} }
# Initialize LLM and call completion # Initialize LLM and call completion
llm = LLM(config=gemini_config) llm = LLM(config=gemini_config, service_id='test-service')
llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}]) llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
# Verify that litellm_completion was called with the 'thinking' parameter # Verify that litellm_completion was called with the 'thinking' parameter
@@ -762,7 +762,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
@patch('openhands.llm.llm.litellm.token_counter') @patch('openhands.llm.llm.litellm.token_counter')
def test_get_token_count_with_dict_messages(mock_token_counter, default_config): def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
mock_token_counter.return_value = 42 mock_token_counter.return_value = 42
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}] messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages) token_count = llm.get_token_count(messages)
@@ -777,7 +777,7 @@ def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
def test_get_token_count_with_message_objects( def test_get_token_count_with_message_objects(
mock_token_counter, default_config, mock_logger mock_token_counter, default_config, mock_logger
): ):
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
# Create a Message object and its equivalent dict # Create a Message object and its equivalent dict
message_obj = Message(role='user', content=[TextContent(text='Hello!')]) message_obj = Message(role='user', content=[TextContent(text='Hello!')])
@@ -806,7 +806,7 @@ def test_get_token_count_with_custom_tokenizer(
config = copy.deepcopy(default_config) config = copy.deepcopy(default_config)
config.custom_tokenizer = 'custom/tokenizer' config.custom_tokenizer = 'custom/tokenizer'
llm = LLM(config) llm = LLM(config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}] messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages) token_count = llm.get_token_count(messages)
@@ -823,7 +823,7 @@ def test_get_token_count_error_handling(
mock_token_counter, default_config, mock_logger mock_token_counter, default_config, mock_logger
): ):
mock_token_counter.side_effect = Exception('Token counting failed') mock_token_counter.side_effect = Exception('Token counting failed')
llm = LLM(default_config) llm = LLM(default_config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Hello!'}] messages = [{'role': 'user', 'content': 'Hello!'}]
token_count = llm.get_token_count(messages) token_count = llm.get_token_count(messages)
@@ -865,7 +865,7 @@ def test_llm_token_usage(mock_litellm_completion, default_config):
# We'll make mock_litellm_completion return these responses in sequence # We'll make mock_litellm_completion return these responses in sequence
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2] mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
# First call # First call
llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}]) llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
@@ -924,7 +924,7 @@ def test_accumulated_token_usage(mock_litellm_completion, default_config):
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2] mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]
# Create LLM instance # Create LLM instance
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
# First call # First call
llm.completion(messages=[{'role': 'user', 'content': 'First message'}]) llm.completion(messages=[{'role': 'user', 'content': 'First message'}])
@@ -980,7 +980,7 @@ def test_completion_with_log_completions(mock_litellm_completion, default_config
} }
mock_litellm_completion.return_value = mock_response mock_litellm_completion.return_value = mock_response
test_llm = LLM(config=default_config) test_llm = LLM(config=default_config, service_id='test-service')
response = test_llm.completion( response = test_llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -1006,7 +1006,7 @@ def test_llm_base_url_auto_protocol_patch(mock_get):
mock_get.return_value.status_code = 200 mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {'model': 'fake'} mock_get.return_value.json.return_value = {'model': 'fake'}
llm = LLM(config=config) llm = LLM(config=config, service_id='test-service')
llm.init_model_info() llm.init_model_info()
called_url = mock_get.call_args[0][0] called_url = mock_get.call_args[0][0]
@@ -1020,7 +1020,7 @@ def test_unknown_model_token_limits():
"""Test that models without known token limits get None for both max_output_tokens and max_input_tokens.""" """Test that models without known token limits get None for both max_output_tokens and max_input_tokens."""
# Create LLM instance with a non-existent model to avoid litellm having model info for it # Create LLM instance with a non-existent model to avoid litellm having model info for it
config = LLMConfig(model='non-existent-model', api_key='test_key') config = LLMConfig(model='non-existent-model', api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
# Verify max_output_tokens and max_input_tokens are initialized to None (default value) # Verify max_output_tokens and max_input_tokens are initialized to None (default value)
assert llm.config.max_output_tokens is None assert llm.config.max_output_tokens is None
@@ -1031,7 +1031,7 @@ def test_max_tokens_from_model_info():
"""Test that max_output_tokens and max_input_tokens are correctly initialized from model info.""" """Test that max_output_tokens and max_input_tokens are correctly initialized from model info."""
# Create LLM instance with GPT-4 model which has known token limits # Create LLM instance with GPT-4 model which has known token limits
config = LLMConfig(model='gpt-4', api_key='test_key') config = LLMConfig(model='gpt-4', api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
# GPT-4 has specific token limits # GPT-4 has specific token limits
# These are the expected values from litellm # These are the expected values from litellm
@@ -1043,7 +1043,7 @@ def test_claude_3_7_sonnet_max_output_tokens():
"""Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens.""" """Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens."""
# Create LLM instance with Claude 3.7 Sonnet model # Create LLM instance with Claude 3.7 Sonnet model
config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key') config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
# Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet # Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet
assert llm.config.max_output_tokens == 64000 assert llm.config.max_output_tokens == 64000
@@ -1055,7 +1055,7 @@ def test_claude_sonnet_4_max_output_tokens():
"""Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values.""" """Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values."""
# Create LLM instance with a Claude Sonnet 4 model # Create LLM instance with a Claude Sonnet 4 model
config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key') config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
# Verify max_output_tokens is set to the expected value # Verify max_output_tokens is set to the expected value
assert llm.config.max_output_tokens == 64000 assert llm.config.max_output_tokens == 64000
@@ -1068,7 +1068,7 @@ def test_sambanova_deepseek_model_max_output_tokens():
"""Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value.""" """Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value."""
# Create LLM instance with SambaNova DeepSeek model # Create LLM instance with SambaNova DeepSeek model
config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key') config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key')
llm = LLM(config) llm = LLM(config, service_id='test-service')
# SambaNova DeepSeek model has specific token limits # SambaNova DeepSeek model has specific token limits
# This is the expected value from litellm # This is the expected value from litellm
@@ -1081,7 +1081,7 @@ def test_max_output_tokens_override_in_config():
config = LLMConfig( config = LLMConfig(
model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048 model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048
) )
llm = LLM(config) llm = LLM(config, service_id='test-service')
# Verify the config has the overridden max_output_tokens value # Verify the config has the overridden max_output_tokens value
assert llm.config.max_output_tokens == 2048 assert llm.config.max_output_tokens == 2048
@@ -1098,7 +1098,7 @@ def test_azure_model_default_max_tokens():
) )
# Create LLM instance with Azure model # Create LLM instance with Azure model
llm = LLM(azure_config) llm = LLM(azure_config, service_id='test-service')
# Verify the config has the default max_output_tokens value # Verify the config has the default max_output_tokens value
assert llm.config.max_output_tokens is None # Default value assert llm.config.max_output_tokens is None # Default value
@@ -1143,7 +1143,7 @@ def test_gemini_none_reasoning_effort_uses_thinking_budget(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, 'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
} }
llm = LLM(config) llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages) llm.completion(messages=sample_messages)
@@ -1167,7 +1167,7 @@ def test_gemini_low_reasoning_effort_uses_thinking_budget(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, 'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
} }
llm = LLM(config) llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages) llm.completion(messages=sample_messages)
@@ -1191,7 +1191,7 @@ def test_gemini_medium_reasoning_effort_passes_through(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, 'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
} }
llm = LLM(config) llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages) llm.completion(messages=sample_messages)
@@ -1214,7 +1214,7 @@ def test_gemini_high_reasoning_effort_passes_through(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, 'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
} }
llm = LLM(config) llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages) llm.completion(messages=sample_messages)
@@ -1235,7 +1235,7 @@ def test_non_gemini_uses_reasoning_effort(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, 'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
} }
llm = LLM(config) llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages) llm.completion(messages=sample_messages)
@@ -1259,7 +1259,7 @@ def test_non_reasoning_model_no_optimization(mock_completion):
'usage': {'prompt_tokens': 10, 'completion_tokens': 5}, 'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
} }
llm = LLM(config) llm = LLM(config, service_id='test-service')
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}] sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
llm.completion(messages=sample_messages) llm.completion(messages=sample_messages)
@@ -1285,7 +1285,7 @@ def test_gemini_performance_optimization_end_to_end(mock_completion):
assert config.reasoning_effort is None assert config.reasoning_effort is None
# Create LLM and make completion # Create LLM and make completion
llm = LLM(config) llm = LLM(config, service_id='test-service')
messages = [{'role': 'user', 'content': 'Solve this complex problem'}] messages = [{'role': 'user', 'content': 'Solve this complex problem'}]
response = llm.completion(messages=messages) response = llm.completion(messages=messages)

View File

@@ -207,7 +207,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
), ),
] ]
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextIssue( handler = ServiceContextIssue(
GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
) )
@@ -251,7 +251,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
) )
# Initialize LLM and handler # Initialize LLM and handler
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextPR( handler = ServiceContextPR(
GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config
) )

View File

@@ -463,7 +463,7 @@ async def test_process_issue(
[], [],
) )
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
handler_instance.llm = LLM(llm_config) handler_instance.llm = LLM(llm_config, service_id='test-service')
# Mock the runtime and its methods # Mock the runtime and its methods
mock_runtime = MagicMock() mock_runtime = MagicMock()

View File

@@ -209,7 +209,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
), ),
] ]
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextIssue( handler = ServiceContextIssue(
GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
) )
@@ -253,7 +253,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
) )
# Initialize LLM and handler # Initialize LLM and handler
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
handler = ServiceContextPR( handler = ServiceContextPR(
GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config
) )

View File

@@ -500,7 +500,7 @@ async def test_process_issue(
[], [],
) )
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue' handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
handler_instance.llm = LLM(llm_config) handler_instance.llm = LLM(llm_config, service_id='test-service')
# Create mock runtime and mock run_controller # Create mock runtime and mock run_controller
mock_runtime = MagicMock() mock_runtime = MagicMock()

View File

@@ -18,6 +18,7 @@ from openhands.controller.state.control_flags import (
from openhands.controller.state.state import State from openhands.controller.state.state import State
from openhands.core.config import OpenHandsConfig from openhands.core.config import OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.main import run_controller from openhands.core.main import run_controller
from openhands.core.schema import AgentState from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
@@ -33,6 +34,7 @@ from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.empty import NullObservation from openhands.events.observation.empty import NullObservation
from openhands.events.serialization import event_to_dict from openhands.events.serialization import event_to_dict
from openhands.llm import LLM from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.llm.metrics import Metrics, TokenUsage from openhands.llm.metrics import Metrics, TokenUsage
from openhands.memory.condenser.condenser import Condensation from openhands.memory.condenser.condenser import Condensation
from openhands.memory.condenser.impl.conversation_window_condenser import ( from openhands.memory.condenser.impl.conversation_window_condenser import (
@@ -45,6 +47,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient, ActionExecutionClient,
) )
from openhands.runtime.runtime_status import RuntimeStatus from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore from openhands.storage.memory import InMemoryFileStore
@@ -61,15 +64,43 @@ def event_loop():
@pytest.fixture @pytest.fixture
def mock_agent(): def mock_agent_with_stats():
agent = MagicMock(spec=Agent) """Create a mock agent with properly connected LLM registry and conversation stats."""
agent.llm = MagicMock(spec=LLM) import uuid
agent.llm.metrics = Metrics()
agent.llm.config = OpenHandsConfig().get_llm_config()
# Add config with enable_mcp attribute # Create LLM registry
agent.config = MagicMock(spec=AgentConfig) config = OpenHandsConfig()
agent.config.enable_mcp = True llm_registry = LLMRegistry(config=config)
# Create conversation stats
file_store = InMemoryFileStore({})
conversation_id = f'test-conversation-{uuid.uuid4()}'
conversation_stats = ConversationStats(
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
)
# Connect registry to stats (this is the key requirement)
llm_registry.subscribe(conversation_stats.register_llm)
# Create mock agent
agent = MagicMock(spec=Agent)
agent_config = MagicMock(spec=AgentConfig)
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
agent_config.disabled_microagents = []
agent_config.enable_mcp = True
llm_registry.service_to_llm.clear()
mock_llm = llm_registry.get_llm('agent_llm', llm_config)
agent.llm = mock_llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
# Add a proper system message mock # Add a proper system message mock
system_message = SystemMessageAction( system_message = SystemMessageAction(
@@ -79,7 +110,7 @@ def mock_agent():
system_message._id = -1 # Set invalid ID to avoid the ID check system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message agent.get_system_message.return_value = system_message
return agent return agent, conversation_stats, llm_registry
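The wiring this fixture depends on, shown in isolation as a sketch built only from calls that appear in the diff above; exact model and key values are illustrative:

from openhands.core.config import OpenHandsConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

config = OpenHandsConfig()
llm_registry = LLMRegistry(config=config)

conversation_stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='test-conversation',
    user_id='test-user',
)

# The key step: stats subscribe to registry events, so every LLM the registry
# creates (for example the agent's 'agent_llm' service) is tracked per conversation.
llm_registry.subscribe(conversation_stats.register_llm)

llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
agent_llm = llm_registry.get_llm('agent_llm', llm_config)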
@pytest.fixture @pytest.fixture
@@ -134,10 +165,13 @@ async def send_event_to_controller(controller, event):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_agent_state(mock_agent, mock_event_stream): async def test_set_agent_state(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -152,10 +186,13 @@ async def test_set_agent_state(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_event_message_action(mock_agent, mock_event_stream): async def test_on_event_message_action(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -169,10 +206,15 @@ async def test_on_event_message_action(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream): async def test_on_event_change_agent_state_action(
mock_agent_with_stats, mock_event_stream
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -186,10 +228,17 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback): async def test_react_to_exception(
mock_agent_with_stats,
mock_event_stream,
mock_status_callback,
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
status_callback=mock_status_callback, status_callback=mock_status_callback,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
@@ -204,12 +253,17 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_react_to_content_policy_violation( async def test_react_to_content_policy_violation(
mock_agent, mock_event_stream, mock_status_callback mock_agent_with_stats,
mock_event_stream,
mock_status_callback,
): ):
"""Test that the controller properly handles content policy violations from the LLM.""" """Test that the controller properly handles content policy violations from the LLM."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
status_callback=mock_status_callback, status_callback=mock_status_callback,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
@@ -246,18 +300,16 @@ async def test_react_to_content_policy_violation(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_controller_with_fatal_error( async def test_run_controller_with_fatal_error(
test_event_stream, mock_memory, mock_agent test_event_stream, mock_memory, mock_agent_with_stats
): ):
config = OpenHandsConfig() config = OpenHandsConfig()
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
def agent_step_fn(state): def agent_step_fn(state):
print(f'agent_step_fn received state: {state}') print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls') return CmdRunAction(command='ls')
mock_agent.step = agent_step_fn mock_agent.step = agent_step_fn
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=ActionExecutionClient) runtime = MagicMock(spec=ActionExecutionClient)
@@ -284,14 +336,16 @@ async def test_run_controller_with_fatal_error(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
) )
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller( state = await run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='Test message'), initial_user_action=MessageAction(content='Test message'),
runtime=runtime, runtime=runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory, memory=mock_memory,
llm_registry=llm_registry,
) )
print(f'state: {state}') print(f'state: {state}')
events = list(test_event_stream.get_events()) events = list(test_event_stream.get_events())
@@ -312,18 +366,16 @@ async def test_run_controller_with_fatal_error(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_controller_stop_with_stuck( async def test_run_controller_stop_with_stuck(
test_event_stream, mock_memory, mock_agent test_event_stream, mock_memory, mock_agent_with_stats
): ):
config = OpenHandsConfig() config = OpenHandsConfig()
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
def agent_step_fn(state): def agent_step_fn(state):
print(f'agent_step_fn received state: {state}') print(f'agent_step_fn received state: {state}')
return CmdRunAction(command='ls') return CmdRunAction(command='ls')
mock_agent.step = agent_step_fn mock_agent.step = agent_step_fn
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
runtime = MagicMock(spec=ActionExecutionClient) runtime = MagicMock(spec=ActionExecutionClient)
@@ -352,14 +404,16 @@ async def test_run_controller_stop_with_stuck(
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
) )
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller( state = await run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='Test message'), initial_user_action=MessageAction(content='Test message'),
runtime=runtime, runtime=runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory, memory=mock_memory,
llm_registry=llm_registry,
) )
events = list(test_event_stream.get_events()) events = list(test_event_stream.get_events())
print(f'state: {state}') print(f'state: {state}')
@@ -391,11 +445,14 @@ async def test_run_controller_stop_with_stuck(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_max_iterations_extension(mock_agent, mock_event_stream): async def test_max_iterations_extension(mock_agent_with_stats, mock_event_stream):
# Test with headless_mode=False - should extend max_iterations # Test with headless_mode=False - should extend max_iterations
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -426,6 +483,7 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -450,7 +508,9 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_step_max_budget(mock_agent, mock_event_stream): async def test_step_max_budget(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Metrics are always synced with budget flag before # Metrics are always synced with budget flag before
metrics = Metrics() metrics = Metrics()
metrics.accumulated_cost = 10.1 metrics.accumulated_cost = 10.1
@@ -458,9 +518,13 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
limit_increase_amount=10, current_value=10.1, max_value=10 limit_increase_amount=10, current_value=10.1, max_value=10
) )
# Update agent's LLM metrics in place
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
budget_per_task_delta=10, budget_per_task_delta=10,
sid='test', sid='test',
@@ -475,7 +539,9 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_step_max_budget_headless(mock_agent, mock_event_stream): async def test_step_max_budget_headless(mock_agent_with_stats, mock_event_stream):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Metrics are always synced with budget flag before # Metrics are always synced with budget flag before
metrics = Metrics() metrics = Metrics()
metrics.accumulated_cost = 10.1 metrics.accumulated_cost = 10.1
@@ -483,9 +549,13 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
limit_increase_amount=10, current_value=10.1, max_value=10 limit_increase_amount=10, current_value=10.1, max_value=10
) )
# Update agent's LLM metrics in place
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
budget_per_task_delta=10, budget_per_task_delta=10,
sid='test', sid='test',
@@ -500,12 +570,14 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_budget_reset_on_continue(mock_agent, mock_event_stream): async def test_budget_reset_on_continue(mock_agent_with_stats, mock_event_stream):
"""Test that when a user continues after hitting the budget limit: """Test that when a user continues after hitting the budget limit:
1. Error is thrown when budget cap is exceeded 1. Error is thrown when budget cap is exceeded
2. LLM budget does not reset when user continues 2. LLM budget does not reset when user continues
3. Budget is extended by adding the initial budget cap to the current accumulated cost 3. Budget is extended by adding the initial budget cap to the current accumulated cost
""" """
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Create a real Metrics instance shared between controller state and llm # Create a real Metrics instance shared between controller state and llm
metrics = Metrics() metrics = Metrics()
metrics.accumulated_cost = 6.0 metrics.accumulated_cost = 6.0
@@ -521,10 +593,14 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
), ),
) )
# Update agent's LLM metrics in place
mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
# Create controller with budget cap # Create controller with budget cap
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
budget_per_task_delta=initial_budget, budget_per_task_delta=initial_budget,
sid='test', sid='test',
@@ -570,11 +646,17 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream): async def test_reset_with_pending_action_no_observation(
mock_agent_with_stats, mock_event_stream
):
"""Test reset() when there's a pending action with tool call metadata but no observation.""" """Test reset() when there's a pending action with tool call metadata but no observation."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -617,11 +699,17 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_stream): async def test_reset_with_pending_action_stopped_state(
mock_agent_with_stats, mock_event_stream
):
"""Test reset() when there's a pending action and agent state is STOPPED.""" """Test reset() when there's a pending action and agent state is STOPPED."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -665,12 +753,16 @@ async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_st
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation( async def test_reset_with_pending_action_existing_observation(
mock_agent, mock_event_stream mock_agent_with_stats, mock_event_stream
): ):
"""Test reset() when there's a pending action with tool call metadata and an existing observation.""" """Test reset() when there's a pending action with tool call metadata and an existing observation."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -708,11 +800,15 @@ async def test_reset_with_pending_action_existing_observation(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream): async def test_reset_without_pending_action(mock_agent_with_stats, mock_event_stream):
"""Test reset() when there's no pending action.""" """Test reset() when there's no pending action."""
# Connect LLM registry to conversation stats
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -738,12 +834,15 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata( async def test_reset_with_pending_action_no_metadata(
mock_agent, mock_event_stream, monkeypatch mock_agent_with_stats, mock_event_stream, monkeypatch
): ):
"""Test reset() when there's a pending action without tool call metadata.""" """Test reset() when there's a pending action without tool call metadata."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -782,16 +881,13 @@ async def test_reset_with_pending_action_no_metadata(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_controller_max_iterations_has_metrics( async def test_run_controller_max_iterations_has_metrics(
test_event_stream, mock_memory, mock_agent test_event_stream, mock_memory, mock_agent_with_stats
): ):
config = OpenHandsConfig( config = OpenHandsConfig(
max_iterations=3, max_iterations=3,
) )
event_stream = test_event_stream event_stream = test_event_stream
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = config.get_llm_config()
step_count = 0 step_count = 0
@@ -833,14 +929,16 @@ async def test_run_controller_max_iterations_has_metrics(
event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())) event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller( state = await run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='Test message'), initial_user_action=MessageAction(content='Test message'),
runtime=runtime, runtime=runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory, memory=mock_memory,
llm_registry=llm_registry,
) )
state.metrics = mock_agent.llm.metrics state.metrics = mock_agent.llm.metrics
@@ -867,10 +965,17 @@ async def test_run_controller_max_iterations_has_metrics(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback): async def test_notify_on_llm_retry(
mock_agent_with_stats,
mock_event_stream,
mock_status_callback,
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
status_callback=mock_status_callback, status_callback=mock_status_callback,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
@@ -908,9 +1013,15 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
], ],
) )
async def test_context_window_exceeded_error_handling( async def test_context_window_exceeded_error_handling(
context_window_error, mock_agent, mock_runtime, test_event_stream, mock_memory context_window_error,
mock_agent_with_stats,
mock_runtime,
test_event_stream,
mock_memory,
): ):
"""Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact.""" """Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
max_iterations = 5 max_iterations = 5
error_after = 2 error_after = 2
@@ -973,15 +1084,17 @@ async def test_context_window_exceeded_error_handling(
# state is set to error out before then, if this terminates and we have a # state is set to error out before then, if this terminates and we have a
# record of the error being thrown we can be confident that the controller # record of the error being thrown we can be confident that the controller
# handles the truncation correctly. # handles the truncation correctly.
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
final_state = await asyncio.wait_for( final_state = await asyncio.wait_for(
run_controller( run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='INITIAL'), initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime, runtime=mock_runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory, memory=mock_memory,
llm_registry=llm_registry,
), ),
timeout=10, timeout=10,
) )
@@ -1072,9 +1185,13 @@ async def test_context_window_exceeded_error_handling(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_with_truncation( async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_agent, mock_runtime, mock_memory, test_event_stream mock_agent_with_stats,
mock_runtime,
mock_memory,
test_event_stream,
): ):
"""Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON.""" """Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
class StepState: class StepState:
def __init__(self): def __init__(self):
@@ -1121,15 +1238,17 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
mock_runtime.config = copy.deepcopy(config) mock_runtime.config = copy.deepcopy(config)
try: try:
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await asyncio.wait_for( state = await asyncio.wait_for(
run_controller( run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='INITIAL'), initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime, runtime=mock_runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory, memory=mock_memory,
llm_registry=llm_registry,
), ),
timeout=10, timeout=10,
) )
@@ -1156,9 +1275,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded_without_truncation( async def test_run_controller_with_context_window_exceeded_without_truncation(
mock_agent, mock_runtime, mock_memory, test_event_stream mock_agent_with_stats,
mock_runtime,
mock_memory,
test_event_stream,
): ):
"""Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON.""" """Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
class StepState: class StepState:
def __init__(self): def __init__(self):
@@ -1199,15 +1322,17 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
config = OpenHandsConfig(max_iterations=3) config = OpenHandsConfig(max_iterations=3)
mock_runtime.config = copy.deepcopy(config) mock_runtime.config = copy.deepcopy(config)
try: try:
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await asyncio.wait_for( state = await asyncio.wait_for(
run_controller( run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='INITIAL'), initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime, runtime=mock_runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=mock_memory, memory=mock_memory,
llm_registry=llm_registry,
), ),
timeout=10, timeout=10,
) )
@@ -1244,7 +1369,11 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_controller_with_memory_error(test_event_stream, mock_agent): async def test_run_controller_with_memory_error(
test_event_stream, mock_agent_with_stats
):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
config = OpenHandsConfig() config = OpenHandsConfig()
event_stream = test_event_stream event_stream = test_event_stream
@@ -1273,14 +1402,16 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
with patch.object( with patch.object(
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
): ):
# Mock the create_agent function to return our mock agent
with patch('openhands.core.main.create_agent', return_value=mock_agent):
state = await run_controller( state = await run_controller(
config=config, config=config,
initial_user_action=MessageAction(content='Test message'), initial_user_action=MessageAction(content='Test message'),
runtime=runtime, runtime=runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=memory, memory=memory,
llm_registry=llm_registry,
) )
assert state.iteration_flag.current_value == 0 assert state.iteration_flag.current_value == 0
@@ -1289,7 +1420,9 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_action_metrics_copy(mock_agent): async def test_action_metrics_copy(mock_agent_with_stats):
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Setup # Setup
file_store = InMemoryFileStore({}) file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store) event_stream = EventStream(sid='test', file_store=file_store)
@@ -1299,8 +1432,7 @@ async def test_action_metrics_copy(mock_agent):
initial_state = State(metrics=metrics, budget_flag=None) initial_state = State(metrics=metrics, budget_flag=None)
# Create agent with metrics # Update agent's LLM metrics
mock_agent.llm = MagicMock(spec=LLM)
# Add multiple token usages - we should get the last one in the action # Add multiple token usages - we should get the last one in the action
usage1 = TokenUsage( usage1 = TokenUsage(
@@ -1342,6 +1474,11 @@ async def test_action_metrics_copy(mock_agent):
mock_agent.llm.metrics = metrics mock_agent.llm.metrics = metrics
# Register the metrics with the LLM registry
llm_registry.service_to_llm['agent'] = mock_agent.llm
# Manually notify the conversation stats about the LLM registration
llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent'))
# Mock agent step to return an action # Mock agent step to return an action
action = MessageAction(content='Test message') action = MessageAction(content='Test message')
@@ -1354,6 +1491,7 @@ async def test_action_metrics_copy(mock_agent):
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=event_stream, event_stream=event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -1411,12 +1549,13 @@ async def test_action_metrics_copy(mock_agent):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_condenser_metrics_included(mock_agent, test_event_stream): async def test_condenser_metrics_included(mock_agent_with_stats, test_event_stream):
"""Test that metrics from the condenser's LLM are included in the action metrics.""" """Test that metrics from the condenser's LLM are included in the action metrics."""
# Set up agent metrics mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
agent_metrics = Metrics(model_name='agent-model')
agent_metrics.accumulated_cost = 0.05 # Set up agent metrics in place
agent_metrics._accumulated_token_usage = TokenUsage( mock_agent.llm.metrics.accumulated_cost = 0.05
mock_agent.llm.metrics._accumulated_token_usage = TokenUsage(
model='agent-model', model='agent-model',
prompt_tokens=100, prompt_tokens=100,
completion_tokens=50, completion_tokens=50,
@@ -1424,7 +1563,6 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
cache_write_tokens=10, cache_write_tokens=10,
response_id='agent-accumulated', response_id='agent-accumulated',
) )
# mock_agent.llm.metrics = agent_metrics
mock_agent.name = 'TestAgent' mock_agent.name = 'TestAgent'
# Create condenser with its own metrics # Create condenser with its own metrics
@@ -1442,6 +1580,11 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
) )
condenser.llm.metrics = condenser_metrics condenser.llm.metrics = condenser_metrics
# Register the condenser metrics with the LLM registry
llm_registry.service_to_llm['condenser'] = condenser.llm
# Manually notify the conversation stats about the condenser LLM registration
llm_registry.notify(RegistryEvent(llm=condenser.llm, service_id='condenser'))
# Attach the condenser to the mock_agent # Attach the condenser to the mock_agent
mock_agent.condenser = condenser mock_agent.condenser = condenser
@@ -1463,11 +1606,12 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=test_event_stream, event_stream=test_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
headless_mode=True, headless_mode=True,
initial_state=State(metrics=agent_metrics, budget_flag=None), initial_state=State(metrics=mock_agent.llm.metrics, budget_flag=None),
) )
# Execute one step # Execute one step
@@ -1505,7 +1649,9 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_first_user_message_with_identical_content(test_event_stream, mock_agent): async def test_first_user_message_with_identical_content(
test_event_stream, mock_agent_with_stats
):
"""Test that _first_user_message correctly identifies the first user message. """Test that _first_user_message correctly identifies the first user message.
This test verifies that messages with identical content but different IDs are properly This test verifies that messages with identical content but different IDs are properly
@@ -1514,14 +1660,12 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
The issue we're checking is that the comparison (action == self._first_user_message()) The issue we're checking is that the comparison (action == self._first_user_message())
should correctly differentiate between messages with the same content but different IDs. should correctly differentiate between messages with the same content but different IDs.
""" """
# Create an agent controller mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=test_event_stream, event_stream=test_event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -1569,11 +1713,15 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_controller_processes_null_observation_with_cause(): async def test_agent_controller_processes_null_observation_with_cause(
mock_agent_with_stats,
):
"""Test that AgentController processes NullObservation events with a cause value. """Test that AgentController processes NullObservation events with a cause value.
And that the agent's step method is called as a result. And that the agent's step method is called as a result.
""" """
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Create an in-memory file store and real event stream # Create an in-memory file store and real event stream
file_store = InMemoryFileStore() file_store = InMemoryFileStore()
event_stream = EventStream(sid='test-session', file_store=file_store) event_stream = EventStream(sid='test-session', file_store=file_store)
@@ -1581,19 +1729,11 @@ async def test_agent_controller_processes_null_observation_with_cause():
# Create a Memory instance - not used directly in this test but needed for setup # Create a Memory instance - not used directly in this test but needed for setup
Memory(event_stream=event_stream, sid='test-session') Memory(event_stream=event_stream, sid='test-session')
# Create a mock agent with necessary attributes
mock_agent = MagicMock(spec=Agent)
mock_agent.get_system_message = MagicMock(
return_value=None,
)
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = OpenHandsConfig().get_llm_config()
# Create a controller with the mock agent # Create a controller with the mock agent
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=event_stream, event_stream=event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test-session', sid='test-session',
) )
@@ -1655,8 +1795,12 @@ async def test_agent_controller_processes_null_observation_with_cause():
) )
def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent): def test_agent_controller_should_step_with_null_observation_cause_zero(
mock_agent_with_stats,
):
"""Test that AgentController's should_step method returns False for NullObservation with cause = 0.""" """Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
# Create a mock event stream # Create a mock event stream
file_store = InMemoryFileStore() file_store = InMemoryFileStore()
event_stream = EventStream(sid='test-session', file_store=file_store) event_stream = EventStream(sid='test-session', file_store=file_store)
@@ -1665,6 +1809,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
controller = AgentController( controller = AgentController(
agent=mock_agent, agent=mock_agent,
event_stream=event_stream, event_stream=event_stream,
convo_stats=conversation_stats,
iteration_delta=10, iteration_delta=10,
sid='test-session', sid='test-session',
) )
@@ -1683,10 +1828,15 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
) )
def test_system_message_in_event_stream(mock_agent, test_event_stream): def test_system_message_in_event_stream(mock_agent_with_stats, test_event_stream):
"""Test that SystemMessageAction is added to event stream in AgentController.""" """Test that SystemMessageAction is added to event stream in AgentController."""
mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
_ = AgentController( _ = AgentController(
agent=mock_agent, event_stream=test_event_stream, iteration_delta=10 agent=mock_agent,
event_stream=test_event_stream,
convo_stats=conversation_stats,
iteration_delta=10,
) )
# Get events from the event stream # Get events from the event stream
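Note: the controller tests above all follow the same wiring introduced by this refactor: a `mock_agent_with_stats` fixture yields `(mock_agent, conversation_stats, llm_registry)`, `AgentController` receives `convo_stats=`, and `run_controller` receives `llm_registry=` while `openhands.core.main.create_agent` is patched. A minimal sketch of the registry/stats relationship those tests exercise is shown below; it uses only names that appear in this diff, and assumes (as the tests imply) that `get_llm` notifies subscribers so `ConversationStats` can aggregate metrics.

```python
from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

config = OpenHandsConfig()
registry = LLMRegistry(config=config)
stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='demo-conversation',
    user_id='demo-user',
)

# ConversationStats subscribes to registry events, so every LLM created
# through the registry reports into one combined metrics view.
registry.subscribe(stats.register_llm)

llm = registry.get_llm('agent', config.get_llm_config())
llm.metrics.add_cost(0.05)

# Combined metrics now reflect the cost recorded on the registered LLM.
assert stats.get_combined_metrics().accumulated_cost == 0.05
```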

View File

@@ -12,8 +12,9 @@ from openhands.controller.state.control_flags import (
IterationControlFlag, IterationControlFlag,
) )
from openhands.controller.state.state import State from openhands.controller.state.state import State
from openhands.core.config import LLMConfig from openhands.core.config import OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig from openhands.core.config.agent_config import AgentConfig
from openhands.core.config.llm_config import LLMConfig
from openhands.core.schema import AgentState from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream from openhands.events import EventSource, EventStream
from openhands.events.action import ( from openhands.events.action import (
@@ -28,11 +29,39 @@ from openhands.events.event import Event, RecallType
from openhands.events.observation.agent import RecallObservation from openhands.events.observation.agent import RecallObservation
from openhands.events.stream import EventStreamSubscriber from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory from openhands.memory.memory import Memory
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def llm_registry():
config = OpenHandsConfig()
return LLMRegistry(config=config)
@pytest.fixture
def conversation_stats():
import uuid
file_store = InMemoryFileStore({})
# Use a unique conversation ID for each test to avoid conflicts
conversation_id = f'test-conversation-{uuid.uuid4()}'
return ConversationStats(
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
)
@pytest.fixture
def connected_registry_and_stats(llm_registry, conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly"""
# Subscribe to LLM registry events to track metrics
llm_registry.subscribe(conversation_stats.register_llm)
return llm_registry, conversation_stats
@pytest.fixture @pytest.fixture
def mock_event_stream(): def mock_event_stream():
"""Creates an event stream in memory.""" """Creates an event stream in memory."""
@@ -42,15 +71,17 @@ def mock_event_stream():
@pytest.fixture @pytest.fixture
def mock_parent_agent(): def mock_parent_agent(llm_registry):
"""Creates a mock parent agent for testing delegation.""" """Creates a mock parent agent for testing delegation."""
agent = MagicMock(spec=Agent) agent = MagicMock(spec=Agent)
agent.name = 'ParentAgent' agent.name = 'ParentAgent'
agent.llm = MagicMock(spec=LLM) agent.llm = MagicMock(spec=LLM)
agent.llm.service_id = 'main_agent'
agent.llm.metrics = Metrics() agent.llm.metrics = Metrics()
agent.llm.config = LLMConfig() agent.llm.config = LLMConfig()
agent.llm.retry_listener = None # Add retry_listener attribute agent.llm.retry_listener = None # Add retry_listener attribute
agent.config = AgentConfig() agent.config = AgentConfig()
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
# Add a proper system message mock # Add a proper system message mock
system_message = SystemMessageAction(content='Test system message') system_message = SystemMessageAction(content='Test system message')
@@ -61,15 +92,17 @@ def mock_parent_agent():
@pytest.fixture @pytest.fixture
def mock_child_agent(): def mock_child_agent(llm_registry):
"""Creates a mock child agent for testing delegation.""" """Creates a mock child agent for testing delegation."""
agent = MagicMock(spec=Agent) agent = MagicMock(spec=Agent)
agent.name = 'ChildAgent' agent.name = 'ChildAgent'
agent.llm = MagicMock(spec=LLM) agent.llm = MagicMock(spec=LLM)
agent.llm.service_id = 'main_agent'
agent.llm.metrics = Metrics() agent.llm.metrics = Metrics()
agent.llm.config = LLMConfig() agent.llm.config = LLMConfig()
agent.llm.retry_listener = None # Add retry_listener attribute agent.llm.retry_listener = None # Add retry_listener attribute
agent.config = AgentConfig() agent.config = AgentConfig()
agent.llm_registry = llm_registry # Add the missing llm_registry attribute
system_message = SystemMessageAction(content='Test system message') system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT system_message._source = EventSource.AGENT
@@ -78,15 +111,37 @@ def mock_child_agent():
return agent return agent
def create_mock_agent_factory(mock_child_agent, llm_registry):
"""Helper function to create a mock agent factory with proper LLM registration."""
def create_mock_agent(config, llm_registry=None):
# Register the mock agent's LLM in the registry so get_combined_metrics() can find it
if llm_registry:
mock_child_agent.llm = llm_registry.get_llm('agent_llm', LLMConfig())
mock_child_agent.llm_registry = (
llm_registry # Set the llm_registry attribute
)
return mock_child_agent
return create_mock_agent
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream): async def test_delegation_flow(
"""Test that when the parent agent delegates to a child mock_parent_agent, mock_child_agent, mock_event_stream, connected_registry_and_stats
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly. ):
2. metrics are accumulated globally (delegate is adding to the parents metrics)
3. local metrics for the delegate are still accessible
""" """
Test that when the parent agent delegates to a child
1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
2. metrics are accumulated globally via LLM registry (delegate adds to the global metrics)
3. global metrics tracking works correctly through the LLM registry
"""
llm_registry, conversation_stats = connected_registry_and_stats
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent # Mock the agent class resolution so that AgentController can instantiate mock_child_agent
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent) Agent.get_cls = Mock(
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
)
step_count = 0 step_count = 0
@@ -97,6 +152,12 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
mock_child_agent.step = agent_step_fn mock_child_agent.step = agent_step_fn
# Set up the parent agent's LLM with initial cost and register it in the registry
# The parent agent's LLM should use the existing registered LLM to ensure proper tracking
parent_llm = llm_registry.service_to_llm['agent']
parent_llm.metrics.accumulated_cost = 2
mock_parent_agent.llm = parent_llm
parent_metrics = Metrics() parent_metrics = Metrics()
parent_metrics.accumulated_cost = 2 parent_metrics.accumulated_cost = 2
# Create parent controller # Create parent controller
@@ -114,6 +175,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
parent_controller = AgentController( parent_controller = AgentController(
agent=mock_parent_agent, agent=mock_parent_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=1, # Add the required iteration_delta parameter iteration_delta=1, # Add the required iteration_delta parameter
sid='parent', sid='parent',
confirmation_mode=False, confirmation_mode=False,
@@ -180,21 +242,23 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
for i in range(4): for i in range(4):
delegate_controller.state.iteration_flag.step() delegate_controller.state.iteration_flag.step()
delegate_controller.agent.step(delegate_controller.state) delegate_controller.agent.step(delegate_controller.state)
# Update the agent's LLM metrics (not the deprecated state metrics)
delegate_controller.agent.llm.metrics.add_cost(1.0) delegate_controller.agent.llm.metrics.add_cost(1.0)
assert ( assert (
delegate_controller.state.get_local_step() == 4 delegate_controller.state.get_local_step() == 4
) # verify local metrics are accessible via snapshot ) # verify local metrics are accessible via snapshot
# Check that the conversation stats has the combined metrics (parent + delegate)
combined_metrics = delegate_controller.state.convo_stats.get_combined_metrics()
assert ( assert (
delegate_controller.state.metrics.accumulated_cost combined_metrics.accumulated_cost
== 6 # Make sure delegate tracks global cost == 6 # Make sure delegate tracks global cost (2 from parent + 4 from delegate)
) )
assert ( # Since metrics are now global via LLM registry, local metrics tracking
delegate_controller.state.get_local_metrics().accumulated_cost # is handled differently. The delegate's LLM shares the same metrics object
== 4 # Delegate spent one dollar per step # as the parent for global tracking, so we verify the global total is correct.
)
delegate_controller.state.outputs = {'delegate_result': 'done'} delegate_controller.state.outputs = {'delegate_result': 'done'}
@@ -228,15 +292,18 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
], ],
) )
async def test_delegate_step_different_states( async def test_delegate_step_different_states(
mock_parent_agent, mock_event_stream, delegate_state mock_parent_agent, mock_event_stream, delegate_state, connected_registry_and_stats
): ):
"""Ensure that delegate is closed or remains open based on the delegate's state.""" """Ensure that delegate is closed or remains open based on the delegate's state."""
llm_registry, conversation_stats = connected_registry_and_stats
# Create a state with iteration_flag.max_value set to 10 # Create a state with iteration_flag.max_value set to 10
state = State(inputs={}) state = State(inputs={})
state.iteration_flag.max_value = 10 state.iteration_flag.max_value = 10
controller = AgentController( controller = AgentController(
agent=mock_parent_agent, agent=mock_parent_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=1, # Add the required iteration_delta parameter iteration_delta=1, # Add the required iteration_delta parameter
sid='test', sid='test',
confirmation_mode=False, confirmation_mode=False,
@@ -292,11 +359,23 @@ async def test_delegate_step_different_states(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delegate_hits_global_limits( async def test_delegate_hits_global_limits(
mock_child_agent, mock_event_stream, mock_parent_agent mock_child_agent, mock_event_stream, mock_parent_agent, connected_registry_and_stats
): ):
"""Global limits from control flags should apply to delegates""" """
Global limits from control flags should apply to delegates
"""
llm_registry, conversation_stats = connected_registry_and_stats
# Mock the agent class resolution so that AgentController can instantiate mock_child_agent # Mock the agent class resolution so that AgentController can instantiate mock_child_agent
Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent) Agent.get_cls = Mock(
return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
)
# Set up the parent agent's LLM with initial cost and register it in the registry
mock_parent_agent.llm.metrics.accumulated_cost = 2
mock_parent_agent.llm.service_id = 'main_agent'
# Register the parent agent's LLM in the registry
llm_registry.service_to_llm['main_agent'] = mock_parent_agent.llm
parent_metrics = Metrics() parent_metrics = Metrics()
parent_metrics.accumulated_cost = 2 parent_metrics.accumulated_cost = 2
@@ -315,6 +394,7 @@ async def test_delegate_hits_global_limits(
parent_controller = AgentController( parent_controller = AgentController(
agent=mock_parent_agent, agent=mock_parent_agent,
event_stream=mock_event_stream, event_stream=mock_event_stream,
convo_stats=conversation_stats,
iteration_delta=1, # Add the required iteration_delta parameter iteration_delta=1, # Add the required iteration_delta parameter
sid='parent', sid='parent',
confirmation_mode=False, confirmation_mode=False,
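Note: the delegation tests above depend on the same mechanism: the delegate's LLM comes from the shared registry, so its cost rolls into the conversation's combined metrics instead of a separate per-delegate total. A standalone sketch of that accounting follows; the service names are illustrative and the 2 + 4 = 6 figure mirrors the test assertion above, under the assumption that `get_combined_metrics` sums across every registered service.

```python
from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore

config = OpenHandsConfig()
registry = LLMRegistry(config=config)
stats = ConversationStats(
    file_store=InMemoryFileStore({}),
    conversation_id='delegation-demo',
    user_id='demo-user',
)
registry.subscribe(stats.register_llm)

# Parent and delegate each obtain their LLM from the same registry.
parent_llm = registry.get_llm('main_agent', config.get_llm_config())
delegate_llm = registry.get_llm('agent_llm', config.get_llm_config())

parent_llm.metrics.accumulated_cost = 2
for _ in range(4):  # the delegate spends one dollar per step
    delegate_llm.metrics.add_cost(1.0)

# 2 from the parent + 4 from the delegate, aggregated in one place.
assert stats.get_combined_metrics().accumulated_cost == 6
```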

View File

@@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber from openhands.events import EventStream, EventStreamSubscriber
from openhands.integrations.service_types import ProviderType from openhands.integrations.service_types import ProviderType
from openhands.llm import LLM from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import ( from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient, ActionExecutionClient,
) )
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore from openhands.storage.memory import InMemoryFileStore
@@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore
@pytest.fixture @pytest.fixture
def mock_agent(): def mock_llm_registry():
"""Create a properly configured mock agent with all required nested attributes""" """Create a mock LLM registry that properly simulates LLM registration"""
# Create the base mocks config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry
@pytest.fixture
def mock_conversation_stats():
"""Create a mock ConversationStats that properly simulates metrics tracking"""
file_store = InMemoryFileStore({})
stats = ConversationStats(
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)
return stats
@pytest.fixture
def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly"""
# Subscribe to LLM registry events to track metrics
mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
return mock_llm_registry, mock_conversation_stats
@pytest.fixture
def make_mock_agent():
def _make_mock_agent(llm_registry):
agent = MagicMock(spec=Agent) agent = MagicMock(spec=Agent)
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)
agent_config = MagicMock(spec=AgentConfig) agent_config = MagicMock(spec=AgentConfig)
llm_config = LLMConfig(
# Configure the LLM config model='gpt-4o',
llm_config.model = 'test-model' api_key='test_key',
llm_config.base_url = 'http://test' num_retries=2,
llm_config.max_message_chars = 1000 retry_min_wait=1,
retry_max_wait=2,
# Configure the agent config )
agent_config.disabled_microagents = [] agent_config.disabled_microagents = []
agent_config.enable_mcp = True agent_config.enable_mcp = True
llm_registry.service_to_llm.clear()
# Set up the chain of mocks mock_llm = llm_registry.get_llm('agent_llm', llm_config)
llm.metrics = metrics agent.llm = mock_llm
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent' agent.name = 'test-agent'
agent.sandbox_plugins = [] agent.sandbox_plugins = []
agent.config = agent_config agent.config = agent_config
agent.prompt_manager = MagicMock() agent.prompt_manager = MagicMock()
return agent return agent
return _make_mock_agent
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent): async def test_agent_session_start_with_no_state(
make_mock_agent, mock_llm_registry, mock_conversation_stats
):
"""Test that AgentSession.start() works correctly when there's no state to restore""" """Test that AgentSession.start() works correctly when there's no state to restore"""
mock_agent = make_mock_agent(mock_llm_registry)
# Setup # Setup
file_store = InMemoryFileStore({}) file_store = InMemoryFileStore({})
session = AgentSession( session = AgentSession(
sid='test-session', sid='test-session',
file_store=file_store, file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
) )
# Create a mock runtime and set it up # Create a mock runtime and set it up
@@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent): async def test_agent_session_start_with_restored_state(
make_mock_agent, mock_llm_registry, mock_conversation_stats
):
"""Test that AgentSession.start() works correctly when there's a state to restore""" """Test that AgentSession.start() works correctly when there's a state to restore"""
mock_agent = make_mock_agent(mock_llm_registry)
# Setup # Setup
file_store = InMemoryFileStore({}) file_store = InMemoryFileStore({})
session = AgentSession( session = AgentSession(
sid='test-session', sid='test-session',
file_store=file_store, file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
) )
# Create a mock runtime and set it up # Create a mock runtime and set it up
@@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_metrics_centralization_and_sharing(mock_agent): async def test_metrics_centralization_via_conversation_stats(
"""Test that metrics are centralized and shared between controller and agent.""" make_mock_agent, connected_registry_and_stats
):
"""Test that metrics are centralized through the ConversationStats service."""
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
mock_agent = make_mock_agent(mock_llm_registry)
# Setup # Setup
file_store = InMemoryFileStore({}) file_store = InMemoryFileStore({})
session = AgentSession( session = AgentSession(
sid='test-session', sid='test-session',
file_store=file_store, file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
) )
# Create a mock runtime and set it up # Create a mock runtime and set it up
@@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
memory = Memory(event_stream=mock_event_stream, sid='test-session') memory = Memory(event_stream=mock_event_stream, sid='test-session')
memory.microagents_dir = 'test-dir' memory.microagents_dir = 'test-dir'
# The registry already has a real metrics object set up in the fixture
# Patch necessary components # Patch necessary components
with ( with (
patch( patch(
@@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
max_iterations=10, max_iterations=10,
) )
# Verify that the agent's LLM metrics and controller's state metrics are the same object # Verify that the ConversationStats is properly set up
assert session.controller.agent.llm.metrics is session.controller.state.metrics assert session.controller.state.convo_stats is mock_conversation_stats
# Add some metrics to the agent's LLM # Add some metrics to the agent's LLM (simulating LLM usage)
test_cost = 0.05 test_cost = 0.05
session.controller.agent.llm.metrics.add_cost(test_cost) session.controller.agent.llm.metrics.add_cost(test_cost)
# Verify that the cost is reflected in the controller's state metrics # Verify that the cost is reflected in the combined metrics from the conversation stats
assert session.controller.state.metrics.accumulated_cost == test_cost combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
assert combined_metrics.accumulated_cost == test_cost
# Create a test metrics object to simulate an observation with metrics # Add more cost to simulate additional LLM usage
test_observation_metrics = Metrics() additional_cost = 0.1
test_observation_metrics.add_cost(0.1) session.controller.agent.llm.metrics.add_cost(additional_cost)
# Get the current accumulated cost before merging # Verify the combined metrics reflect the total cost
current_cost = session.controller.state.metrics.accumulated_cost combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
assert combined_metrics.accumulated_cost == test_cost + additional_cost
# Simulate merging metrics from an observation # Reset the agent and verify that combined metrics are preserved
session.controller.state_tracker.merge_metrics(test_observation_metrics)
# Verify that the merged metrics are reflected in both agent and controller
assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
assert (
session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
)
# Reset the agent and verify that metrics are not reset
session.controller.agent.reset() session.controller.agent.reset()
# Metrics should still be the same after reset # Combined metrics should still be preserved after agent reset
assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1 assert (
assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1 session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
assert session.controller.agent.llm.metrics is session.controller.state.metrics == test_cost + additional_cost
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_budget_control_flag_syncs_with_metrics(mock_agent): async def test_budget_control_flag_syncs_with_metrics(
make_mock_agent, connected_registry_and_stats
):
"""Test that BudgetControlFlag's current value matches the accumulated costs.""" """Test that BudgetControlFlag's current value matches the accumulated costs."""
mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
mock_agent = make_mock_agent(mock_llm_registry)
# Setup # Setup
file_store = InMemoryFileStore({}) file_store = InMemoryFileStore({})
session = AgentSession( session = AgentSession(
sid='test-session', sid='test-session',
file_store=file_store, file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
) )
# Create a mock runtime and set it up # Create a mock runtime and set it up
@@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
memory = Memory(event_stream=mock_event_stream, sid='test-session') memory = Memory(event_stream=mock_event_stream, sid='test-session')
memory.microagents_dir = 'test-dir' memory.microagents_dir = 'test-dir'
# The registry already has a real metrics object set up in the fixture
# Patch necessary components # Patch necessary components
with ( with (
patch( patch(
@@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
assert session.controller.state.budget_flag.max_value == 1.0 assert session.controller.state.budget_flag.max_value == 1.0
assert session.controller.state.budget_flag.current_value == 0.0 assert session.controller.state.budget_flag.current_value == 0.0
# Add some metrics to the agent's LLM # Add some metrics to the agent's LLM (simulating LLM usage)
test_cost = 0.05 test_cost = 0.05
session.controller.agent.llm.metrics.add_cost(test_cost) session.controller.agent.llm.metrics.add_cost(test_cost)
@@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
session.controller.state_tracker.sync_budget_flag_with_metrics() session.controller.state_tracker.sync_budget_flag_with_metrics()
assert session.controller.state.budget_flag.current_value == test_cost assert session.controller.state.budget_flag.current_value == test_cost
# Create a test metrics object to simulate an observation with metrics # Add more cost to simulate additional LLM usage
test_observation_metrics = Metrics() additional_cost = 0.1
test_observation_metrics.add_cost(0.1) session.controller.agent.llm.metrics.add_cost(additional_cost)
# Simulate merging metrics from an observation # Sync again and verify the budget flag is updated
session.controller.state_tracker.merge_metrics(test_observation_metrics) session.controller.state_tracker.sync_budget_flag_with_metrics()
assert (
session.controller.state.budget_flag.current_value
== test_cost + additional_cost
)
# Verify that the budget control flag's current value is updated to match the new accumulated cost # Reset the agent and verify that budget flag still reflects the accumulated cost
assert session.controller.state.budget_flag.current_value == test_cost + 0.1
# Reset the agent and verify that metrics and budget flag are not reset
session.controller.agent.reset() session.controller.agent.reset()
# Budget control flag should still reflect the accumulated cost after reset # Budget control flag should still reflect the accumulated cost after reset
assert session.controller.state.budget_flag.current_value == test_cost + 0.1 session.controller.state_tracker.sync_budget_flag_with_metrics()
assert (
session.controller.state.budget_flag.current_value
== test_cost + additional_cost
)
def test_override_provider_tokens_with_custom_secret(): def test_override_provider_tokens_with_custom_secret(
mock_llm_registry, mock_conversation_stats
):
"""Test that override_provider_tokens_with_custom_secret works correctly. """Test that override_provider_tokens_with_custom_secret works correctly.
This test verifies that the method properly removes provider tokens when This test verifies that the method properly removes provider tokens when
@@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
session = AgentSession( session = AgentSession(
sid='test-session', sid='test-session',
file_store=file_store, file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=mock_conversation_stats,
) )
# Create test data # Create test data
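Note: the AgentSession tests above show the new constructor surface: a session is created with an `llm_registry` and a `convo_stats` handle, and metrics/budget checks go through `convo_stats.get_combined_metrics()` rather than a metrics object shared between agent and controller. A minimal construction sketch, limited to the arguments visible in this diff:

```python
from openhands.core.config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore

file_store = InMemoryFileStore({})
registry = LLMRegistry(config=OpenHandsConfig())
convo_stats = ConversationStats(
    file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)
# Connect stats to the registry so LLM usage is tracked per conversation.
registry.subscribe(convo_stats.register_llm)

session = AgentSession(
    sid='test-session',
    file_store=file_store,
    llm_registry=registry,
    convo_stats=convo_stats,
)
```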

View File

@@ -30,6 +30,7 @@ from openhands.agenthub.readonly_agent.tools import (
) )
from openhands.controller.state.state import State from openhands.controller.state.state import State
from openhands.core.config import AgentConfig, LLMConfig from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.exceptions import FunctionCallNotExistsError from openhands.core.exceptions import FunctionCallNotExistsError
from openhands.core.message import ImageContent, Message, TextContent from openhands.core.message import ImageContent, Message, TextContent
from openhands.events.action import ( from openhands.events.action import (
@@ -42,10 +43,20 @@ from openhands.events.observation.commands import (
CmdOutputObservation, CmdOutputObservation,
) )
from openhands.events.tool import ToolCallMetadata from openhands.events.tool import ToolCallMetadata
from openhands.llm.llm import LLM from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser import View from openhands.memory.condenser import View
@pytest.fixture
def create_llm_registry():
def _get_registry(llm_config):
config = OpenHandsConfig()
config.set_llm_config(llm_config)
return LLMRegistry(config=config)
return _get_registry
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent']) @pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
def agent_class(request): def agent_class(request):
if request.param == 'CodeActAgent': if request.param == 'CodeActAgent':
@@ -57,18 +68,22 @@ def agent_class(request):
@pytest.fixture @pytest.fixture
def agent(agent_class) -> Union[CodeActAgent, ReadOnlyAgent]: def agent(agent_class, create_llm_registry) -> Union[CodeActAgent, ReadOnlyAgent]:
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = AgentConfig() config = AgentConfig()
agent = agent_class(llm=LLM(LLMConfig()), config=config) agent = agent_class(config=config, llm_registry=create_llm_registry(llm_config))
agent.llm = Mock() agent.llm = Mock()
agent.llm.config = Mock() agent.llm.config = Mock()
agent.llm.config.max_message_chars = 1000 agent.llm.config.max_message_chars = 1000
return agent return agent
def test_agent_with_default_config_has_default_tools(): def test_agent_with_default_config_has_default_tools(create_llm_registry):
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = AgentConfig() config = AgentConfig()
codeact_agent = CodeActAgent(llm=LLM(LLMConfig()), config=config) codeact_agent = CodeActAgent(
config=config, llm_registry=create_llm_registry(llm_config)
)
assert len(codeact_agent.tools) > 0 assert len(codeact_agent.tools) > 0
default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools] default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools]
assert { assert {
@@ -231,7 +246,7 @@ def test_response_to_actions_invalid_tool():
readonly_response_to_actions(mock_response) readonly_response_to_actions(mock_response)
def test_step_with_no_pending_actions(mock_state: State): def test_step_with_no_pending_actions(mock_state: State, create_llm_registry):
# Mock the LLM response # Mock the LLM response
mock_response = Mock() mock_response = Mock()
mock_response.id = 'mock_id' mock_response.id = 'mock_id'
@@ -252,9 +267,12 @@ def test_step_with_no_pending_actions(mock_state: State):
llm.format_messages_for_llm = Mock(return_value=[]) # Mock message formatting llm.format_messages_for_llm = Mock(return_value=[]) # Mock message formatting
# Create agent with mocked LLM # Create agent with mocked LLM
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = AgentConfig() config = AgentConfig()
config.enable_prompt_extensions = False config.enable_prompt_extensions = False
agent = CodeActAgent(llm=llm, config=config) agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
# Replace the LLM with our mock after creation
agent.llm = llm
# Test step with no pending actions # Test step with no pending actions
mock_state.latest_user_message = None mock_state.latest_user_message = None
@@ -281,15 +299,10 @@ def test_step_with_no_pending_actions(mock_state: State):
@pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent']) @pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent'])
def test_correct_tool_description_loaded_based_on_model_name( def test_correct_tool_description_loaded_based_on_model_name(
agent_type, mock_state: State agent_type, create_llm_registry
): ):
"""Tests that the simplified tool descriptions are loaded for specific models.""" """Tests that the simplified tool descriptions are loaded for specific models."""
o3_mock_config = Mock() o3_mock_config = LLMConfig(model='mock_o3_model', api_key='test_key')
o3_mock_config.model = 'mock_o3_model'
llm = Mock()
llm.config = o3_mock_config
if agent_type == 'CodeActAgent': if agent_type == 'CodeActAgent':
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
@@ -299,16 +312,19 @@ def test_correct_tool_description_loaded_based_on_model_name(
agent_class = ReadOnlyAgent agent_class = ReadOnlyAgent
agent = agent_class(llm=llm, config=AgentConfig()) agent = agent_class(
config=AgentConfig(),
llm_registry=create_llm_registry(o3_mock_config),
)
for tool in agent.tools: for tool in agent.tools:
# Assert all descriptions have less than 1024 characters # Assert all descriptions have less than 1024 characters
assert len(tool['function']['description']) < 1024 assert len(tool['function']['description']) < 1024
sonnet_mock_config = Mock() sonnet_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key')
sonnet_mock_config = Mock() sonnet_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key')
sonnet_mock_config.model = 'mock_sonnet_model' agent = agent_class(
config=AgentConfig(),
llm.config = sonnet_mock_config llm_registry=create_llm_registry(sonnet_mock_config),
agent = agent_class(llm=llm, config=AgentConfig()) )
# Assert existence of the detailed tool descriptions that are longer than 1024 characters # Assert existence of the detailed tool descriptions that are longer than 1024 characters
if agent_type == 'CodeActAgent': if agent_type == 'CodeActAgent':
# This only holds for CodeActAgent # This only holds for CodeActAgent
@@ -481,10 +497,12 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
assert isinstance(enhanced_messages[5].content[0], ImageContent) assert isinstance(enhanced_messages[5].content[0], ImageContent)
def test_get_system_message(): def test_get_system_message(create_llm_registry):
"""Test that the Agent.get_system_message method returns a SystemMessageAction.""" """Test that the Agent.get_system_message method returns a SystemMessageAction."""
# Create a mock agent # Create a mock agent
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig()) config = AgentConfig()
llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
result = agent.get_system_message() result = agent.get_system_message()
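Note: as the fixture changes above show, agents are no longer handed an `LLM` directly; they take an `AgentConfig` plus an `LLMRegistry`, and the registry is seeded with the desired `LLMConfig` through `OpenHandsConfig.set_llm_config`. A sketch of the new construction path (the model and key are placeholder test values, as in the fixtures above):

```python
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
config = OpenHandsConfig()
config.set_llm_config(llm_config)

# The agent resolves its LLM through the registry instead of receiving one.
agent = CodeActAgent(config=AgentConfig(), llm_registry=LLMRegistry(config=config))
assert len(agent.tools) > 0  # default tools are still registered
```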

View File

@@ -34,7 +34,7 @@ def test_completion_retries_api_connection_error(
] ]
# Create an LLM instance and call completion # Create an LLM instance and call completion
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
response = llm.completion( response = llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False, stream=False,
@@ -70,7 +70,7 @@ def test_completion_max_retries_api_connection_error(
] ]
# Create an LLM instance and call completion # Create an LLM instance and call completion
llm = LLM(config=default_config) llm = LLM(config=default_config, service_id='test-service')
# The completion should raise an APIConnectionError after exhausting all retries # The completion should raise an APIConnectionError after exhausting all retries
with pytest.raises(APIConnectionError) as excinfo: with pytest.raises(APIConnectionError) as excinfo:
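Note: a small but pervasive change visible above is that `LLM` instances are constructed with a `service_id`, which is how the registry and `ConversationStats` attribute retries and costs to a particular service. Construction sketch with placeholder config values; `completion()` is called exactly as before:

```python
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

llm = LLM(
    config=LLMConfig(model='gpt-4o', api_key='test_key'),
    service_id='test-service',
)
```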

View File

@@ -5,11 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.events.action import MessageAction from openhands.events.action import MessageAction
from openhands.events.event import EventSource from openhands.events.event import EventSource
from openhands.events.event_store import EventStore from openhands.events.event_store import EventStore
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.conversation_manager.standalone_conversation_manager import ( from openhands.server.conversation_manager.standalone_conversation_manager import (
StandaloneConversationManager, StandaloneConversationManager,
) )
@@ -24,6 +24,7 @@ async def test_auto_generate_title_with_llm():
"""Test auto-generating a title using LLM.""" """Test auto-generating a title using LLM."""
# Mock dependencies # Mock dependencies
file_store = InMemoryFileStore() file_store = InMemoryFileStore()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation with a user message # Create test conversation with a user message
conversation_id = 'test-conversation' conversation_id = 'test-conversation'
@@ -46,13 +47,10 @@ async def test_auto_generate_title_with_llm():
mock_event_store.search_events.return_value = [user_message] mock_event_store.search_events.return_value = [user_message]
mock_event_store_cls.return_value = mock_event_store mock_event_store_cls.return_value = mock_event_store
# Mock the LLM response # Mock the LLM registry response
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls: llm_registry.request_extraneous_completion.return_value = (
mock_llm = mock_llm_cls.return_value 'Python Data Analysis Script'
mock_response = MagicMock() )
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'Python Data Analysis Script'
mock_llm.completion.return_value = mock_response
# Create test settings with LLM config # Create test settings with LLM config
settings = Settings( settings = Settings(
@@ -63,7 +61,7 @@ async def test_auto_generate_title_with_llm():
# Call the auto_generate_title function directly # Call the auto_generate_title function directly
title = await auto_generate_title( title = await auto_generate_title(
conversation_id, user_id, file_store, settings conversation_id, user_id, file_store, settings, llm_registry
) )
# Verify the result # Verify the result
@@ -74,15 +72,8 @@ async def test_auto_generate_title_with_llm():
conversation_id, file_store, user_id conversation_id, file_store, user_id
) )
# Verify LLM was called with appropriate parameters # Verify LLM registry was called with appropriate parameters
mock_llm_cls.assert_called_once_with( llm_registry.request_extraneous_completion.assert_called_once()
LLMConfig(
model='test-model',
api_key='test-key',
base_url='test-url',
)
)
mock_llm.completion.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -90,6 +81,7 @@ async def test_auto_generate_title_fallback():
"""Test auto-generating a title with fallback to truncation when LLM fails.""" """Test auto-generating a title with fallback to truncation when LLM fails."""
# Mock dependencies # Mock dependencies
file_store = InMemoryFileStore() file_store = InMemoryFileStore()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation with a user message # Create test conversation with a user message
conversation_id = 'test-conversation' conversation_id = 'test-conversation'
@@ -111,10 +103,8 @@ async def test_auto_generate_title_fallback():
mock_event_store.search_events.return_value = [user_message] mock_event_store.search_events.return_value = [user_message]
mock_event_store_cls.return_value = mock_event_store mock_event_store_cls.return_value = mock_event_store
# Mock the LLM to raise an exception # Mock the LLM registry to raise an exception
with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls: llm_registry.request_extraneous_completion.side_effect = Exception('Test error')
mock_llm = mock_llm_cls.return_value
mock_llm.completion.side_effect = Exception('Test error')
# Create test settings with LLM config # Create test settings with LLM config
settings = Settings( settings = Settings(
@@ -125,7 +115,7 @@ async def test_auto_generate_title_fallback():
# Call the auto_generate_title function directly # Call the auto_generate_title function directly
title = await auto_generate_title( title = await auto_generate_title(
conversation_id, user_id, file_store, settings conversation_id, user_id, file_store, settings, llm_registry
) )
# Verify the result is a truncated version of the message # Verify the result is a truncated version of the message
@@ -143,6 +133,7 @@ async def test_auto_generate_title_no_messages():
"""Test auto-generating a title when there are no user messages.""" """Test auto-generating a title when there are no user messages."""
# Mock dependencies # Mock dependencies
file_store = InMemoryFileStore() file_store = InMemoryFileStore()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation with no messages # Create test conversation with no messages
conversation_id = 'test-conversation' conversation_id = 'test-conversation'
@@ -166,7 +157,7 @@ async def test_auto_generate_title_no_messages():
# Call the auto_generate_title function directly # Call the auto_generate_title function directly
title = await auto_generate_title( title = await auto_generate_title(
conversation_id, user_id, file_store, settings conversation_id, user_id, file_store, settings, llm_registry
) )
# Verify the result is empty # Verify the result is empty
@@ -186,6 +177,7 @@ async def test_update_conversation_with_title():
sio.emit = AsyncMock() sio.emit = AsyncMock()
file_store = InMemoryFileStore() file_store = InMemoryFileStore()
server_config = MagicMock() server_config = MagicMock()
llm_registry = MagicMock(spec=LLMRegistry)
# Create test conversation # Create test conversation
conversation_id = 'test-conversation' conversation_id = 'test-conversation'
@@ -222,7 +214,9 @@ async def test_update_conversation_with_title():
AsyncMock(return_value='Generated Title'), AsyncMock(return_value='Generated Title'),
): ):
# Call the method # Call the method
await manager._update_conversation_for_event(user_id, conversation_id, settings) await manager._update_conversation_for_event(
user_id, conversation_id, settings, llm_registry
)
# Verify the title was updated # Verify the title was updated
assert mock_metadata.title == 'Generated Title' assert mock_metadata.title == 'Generated Title'
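Note: title generation no longer builds its own `LLM`; `auto_generate_title` receives the conversation's `LLMRegistry` and issues the one-off call through `request_extraneous_completion`, which is why these tests stub the registry instead of patching an LLM class. A mock-based sketch mirroring the test setup above (the `auto_generate_title` arguments are exactly those shown in this diff):

```python
from unittest.mock import MagicMock
from openhands.llm.llm_registry import LLMRegistry

llm_registry = MagicMock(spec=LLMRegistry)
llm_registry.request_extraneous_completion.return_value = 'Python Data Analysis Script'

# The code under test now receives the registry as its last argument:
#   title = await auto_generate_title(
#       conversation_id, user_id, file_store, settings, llm_registry
#   )
# and falls back to truncating the first user message if the call raises.
```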

View File

@@ -6,6 +6,7 @@ import pytest_asyncio
from openhands.cli import main as cli from openhands.cli import main as cli
from openhands.controller.state.state import State from openhands.controller.state.state import State
from openhands.core.config.llm_config import LLMConfig
from openhands.events import EventSource from openhands.events import EventSource
from openhands.events.action import MessageAction from openhands.events.action import MessageAction
@@ -124,12 +125,14 @@ def mock_config():
'' # Empty string, not starting with 'tvly-' '' # Empty string, not starting with 'tvly-'
) )
config.search_api_key = search_api_key_mock config.search_api_key = search_api_key_mock
config.get_llm_config_from_agent.return_value = LLMConfig(model='model')
# Mock sandbox with volumes attribute to prevent finalize_config issues # Mock sandbox with volumes attribute to prevent finalize_config issues
config.sandbox = MagicMock() config.sandbox = MagicMock()
config.sandbox.volumes = ( config.sandbox.volumes = (
None # This prevents finalize_config from overriding workspace_base None # This prevents finalize_config from overriding workspace_base
) )
config.model_name = 'model'
return config return config
@@ -213,7 +216,11 @@ async def test_run_session_without_initial_action(
# Assertions for initialization flow # Assertions for initialization flow
mock_display_runtime_init.assert_called_once_with('local') mock_display_runtime_init.assert_called_once_with('local')
mock_display_animation.assert_called_once() mock_display_animation.assert_called_once()
mock_create_agent.assert_called_once_with(mock_config) # Check that mock_config is the first parameter to create_agent
mock_create_agent.assert_called_once()
assert mock_create_agent.call_args[0][0] == mock_config, (
'First parameter to create_agent should be mock_config'
)
mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory) mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory)
mock_create_runtime.assert_called_once() mock_create_runtime.assert_called_once()
mock_create_controller.assert_called_once() mock_create_controller.assert_called_once()

View File

@@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from litellm.exceptions import AuthenticationError from litellm.exceptions import AuthenticationError
from pydantic import SecretStr
from openhands.cli import main as cli from openhands.cli import main as cli
from openhands.core.config.llm_config import LLMConfig
from openhands.events import EventSource from openhands.events import EventSource
from openhands.events.action import MessageAction from openhands.events.action import MessageAction
@@ -45,11 +47,10 @@ def mock_config():
config.workspace_base = '/test/dir' config.workspace_base = '/test/dir'
# Set up LLM config to use OpenHands provider # Set up LLM config to use OpenHands provider
llm_config = MagicMock() llm_config = LLMConfig(model='openhands/o3', api_key=SecretStr('invalid-api-key'))
llm_config.model = 'openhands/o3' # Use OpenHands provider with o3 model llm_config.model = 'openhands/o3' # Use OpenHands provider with o3 model
llm_config.api_key = MagicMock() config.get_llm_config.return_value = llm_config
llm_config.api_key.get_secret_value.return_value = 'invalid-api-key' config.get_llm_config_from_agent.return_value = llm_config
config.llm = llm_config
# Mock search_api_key with get_secret_value method # Mock search_api_key with get_secret_value method
search_api_key_mock = MagicMock() search_api_key_mock = MagicMock()

View File

@@ -13,6 +13,7 @@ from openhands.core.config.mcp_config import (
from openhands.events.action.mcp import MCPAction from openhands.events.action.mcp import MCPAction
from openhands.events.observation import ErrorObservation from openhands.events.observation import ErrorObservation
from openhands.events.observation.mcp import MCPObservation from openhands.events.observation.mcp import MCPObservation
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
@@ -23,8 +24,12 @@ class TestCLIRuntimeMCP:
"""Set up test fixtures.""" """Set up test fixtures."""
self.config = OpenHandsConfig() self.config = OpenHandsConfig()
self.event_stream = MagicMock() self.event_stream = MagicMock()
llm_registry = LLMRegistry(config=OpenHandsConfig())
self.runtime = CLIRuntime( self.runtime = CLIRuntime(
config=self.config, event_stream=self.event_stream, sid='test-session' config=self.config,
event_stream=self.event_stream,
sid='test-session',
llm_registry=llm_registry,
) )
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -7,10 +7,18 @@ import pytest
from openhands.core.config import OpenHandsConfig from openhands.core.config import OpenHandsConfig
from openhands.events import EventStream from openhands.events import EventStream
# Mock LLMRegistry
from openhands.runtime.impl.cli.cli_runtime import CLIRuntime from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
from openhands.storage import get_file_store from openhands.storage import get_file_store
# Create a mock LLMRegistry class
class MockLLMRegistry:
def __init__(self, config):
self.config = config
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
"""Create a temporary directory for testing.""" """Create a temporary directory for testing."""
@@ -25,7 +33,8 @@ def cli_runtime(temp_dir):
event_stream = EventStream('test', file_store) event_stream = EventStream('test', file_store)
config = OpenHandsConfig() config = OpenHandsConfig()
config.workspace_base = temp_dir config.workspace_base = temp_dir
runtime = CLIRuntime(config, event_stream) llm_registry = MockLLMRegistry(config)
runtime = CLIRuntime(config, event_stream, llm_registry)
runtime._runtime_initialized = True # Skip initialization runtime._runtime_initialized = True # Skip initialization
return runtime return runtime

View File

@@ -17,6 +17,7 @@ from openhands.core.config.condenser_config import (
StructuredSummaryCondenserConfig, StructuredSummaryCondenserConfig,
) )
from openhands.core.config.llm_config import LLMConfig from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.message import Message, TextContent from openhands.core.message import Message, TextContent
from openhands.core.schema.action import ActionType from openhands.core.schema.action import ActionType
from openhands.events.event import Event, EventSource from openhands.events.event import Event, EventSource
@@ -24,6 +25,7 @@ from openhands.events.observation import BrowserOutputObservation
from openhands.events.observation.agent import AgentCondensationObservation from openhands.events.observation.agent import AgentCondensationObservation
from openhands.events.observation.observation import Observation from openhands.events.observation.observation import Observation
from openhands.llm import LLM from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser import Condenser from openhands.memory.condenser import Condenser
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
from openhands.memory.condenser.impl import ( from openhands.memory.condenser.impl import (
@@ -38,6 +40,7 @@ from openhands.memory.condenser.impl import (
StructuredSummaryCondenser, StructuredSummaryCondenser,
) )
from openhands.memory.condenser.impl.pipeline import CondenserPipeline from openhands.memory.condenser.impl.pipeline import CondenserPipeline
from openhands.server.services.conversation_stats import ConversationStats
def create_test_event( def create_test_event(
@@ -56,12 +59,15 @@ def create_test_event(
@pytest.fixture @pytest.fixture
def mock_llm() -> LLM: def mock_llm() -> LLM:
"""Mocks an LLM object with a utility function for setting and resetting response contents in unit tests.""" """Mocks an LLM object with a utility function for setting and resetting response contents in unit tests."""
# Create a real LLMConfig instead of a mock to properly handle SecretStr api_key
real_config = LLMConfig(
model='gpt-4o', api_key='test_key', custom_llm_provider=None
)
# Create a MagicMock for the LLM object # Create a MagicMock for the LLM object
mock_llm = MagicMock( mock_llm = MagicMock(
spec=LLM, spec=LLM,
config=MagicMock( config=real_config,
spec=LLMConfig, model='gpt-4o', api_key='test_key', custom_llm_provider=None
),
metrics=MagicMock(), metrics=MagicMock(),
) )
_mock_content = None _mock_content = None
@@ -95,6 +101,23 @@ def mock_llm() -> LLM:
return mock_llm return mock_llm
@pytest.fixture
def mock_conversation_stats() -> ConversationStats:
"""Creates a mock ConversationStats service."""
mock_stats = MagicMock(spec=ConversationStats)
return mock_stats
@pytest.fixture
def mock_llm_registry(mock_llm, mock_conversation_stats) -> LLMRegistry:
"""Creates an actual LLMRegistry that returns real LLMs."""
# Create an actual LLMRegistry with a basic OpenHandsConfig
config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry
class RollingCondenserTestHarness: class RollingCondenserTestHarness:
"""Test harness for rolling condensers. """Test harness for rolling condensers.
@@ -165,10 +188,10 @@ class RollingCondenserTestHarness:
return ((index - max_size) // target_size) + 1 return ((index - max_size) // target_size) + 1
def test_noop_condenser_from_config(): def test_noop_condenser_from_config(mock_llm_registry):
"""Test that the NoOpCondenser objects can be made from config.""" """Test that the NoOpCondenser objects can be made from config."""
config = NoOpCondenserConfig() config = NoOpCondenserConfig()
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, NoOpCondenser) assert isinstance(condenser, NoOpCondenser)
@@ -189,11 +212,11 @@ def test_noop_condenser():
assert result == View(events=events) assert result == View(events=events)
def test_observation_masking_condenser_from_config(): def test_observation_masking_condenser_from_config(mock_llm_registry):
"""Test that ObservationMaskingCondenser objects can be made from config.""" """Test that ObservationMaskingCondenser objects can be made from config."""
attention_window = 5 attention_window = 5
config = ObservationMaskingCondenserConfig(attention_window=attention_window) config = ObservationMaskingCondenserConfig(attention_window=attention_window)
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, ObservationMaskingCondenser) assert isinstance(condenser, ObservationMaskingCondenser)
assert condenser.attention_window == attention_window assert condenser.attention_window == attention_window
@@ -229,11 +252,11 @@ def test_observation_masking_condenser_respects_attention_window():
assert event == condensed_event assert event == condensed_event
def test_browser_output_condenser_from_config(): def test_browser_output_condenser_from_config(mock_llm_registry):
"""Test that BrowserOutputCondenser objects can be made from config.""" """Test that BrowserOutputCondenser objects can be made from config."""
attention_window = 5 attention_window = 5
config = BrowserOutputCondenserConfig(attention_window=attention_window) config = BrowserOutputCondenserConfig(attention_window=attention_window)
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, BrowserOutputCondenser) assert isinstance(condenser, BrowserOutputCondenser)
assert condenser.attention_window == attention_window assert condenser.attention_window == attention_window
@@ -271,12 +294,12 @@ def test_browser_output_condenser_respects_attention_window():
assert event == condensed_event assert event == condensed_event
def test_recent_events_condenser_from_config(): def test_recent_events_condenser_from_config(mock_llm_registry):
"""Test that RecentEventsCondenser objects can be made from config.""" """Test that RecentEventsCondenser objects can be made from config."""
max_events = 5 max_events = 5
keep_first = True keep_first = True
config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events) config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events)
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, RecentEventsCondenser) assert isinstance(condenser, RecentEventsCondenser)
assert condenser.max_events == max_events assert condenser.max_events == max_events
@@ -334,14 +357,14 @@ def test_recent_events_condenser():
assert result[2]._message == 'Event 5' # kept from max_events assert result[2]._message == 'Event 5' # kept from max_events
def test_llm_summarizing_condenser_from_config(): def test_llm_summarizing_condenser_from_config(mock_llm_registry):
"""Test that LLMSummarizingCondenser objects can be made from config.""" """Test that LLMSummarizingCondenser objects can be made from config."""
config = LLMSummarizingCondenserConfig( config = LLMSummarizingCondenserConfig(
max_size=50, max_size=50,
keep_first=10, keep_first=10,
llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True), llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True),
) )
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, LLMSummarizingCondenser) assert isinstance(condenser, LLMSummarizingCondenser)
assert condenser.llm.config.model == 'gpt-4o' assert condenser.llm.config.model == 'gpt-4o'
@@ -349,25 +372,33 @@ def test_llm_summarizing_condenser_from_config():
assert condenser.max_size == 50 assert condenser.max_size == 50
assert condenser.keep_first == 10 assert condenser.keep_first == 10
# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt
def test_llm_summarizing_condenser_invalid_config(mock_llm, mock_llm_registry):
def test_llm_summarizing_condenser_invalid_config():
"""Test that LLMSummarizingCondenser raises error when keep_first > max_size.""" """Test that LLMSummarizingCondenser raises error when keep_first > max_size."""
pytest.raises( pytest.raises(
ValueError, ValueError,
LLMSummarizingCondenser, LLMSummarizingCondenser,
llm=MagicMock(), llm=mock_llm,
max_size=4, max_size=4,
keep_first=2, keep_first=2,
) )
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0) pytest.raises(
pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1) ValueError,
LLMSummarizingCondenser,
llm=mock_llm,
max_size=0,
)
pytest.raises(
ValueError,
LLMSummarizingCondenser,
llm=mock_llm,
keep_first=-1,
)
def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm): def test_llm_summarizing_condenser_gives_expected_view_size(
mock_llm, mock_llm_registry
):
"""Test that LLMSummarizingCondenser maintains the correct view size.""" """Test that LLMSummarizingCondenser maintains the correct view size."""
max_size = 10 max_size = 10
condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm) condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm)
@@ -383,12 +414,16 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size) assert len(view) == harness.expected_size(i, max_size)
def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm): def test_llm_summarizing_condenser_keeps_first_and_summary_events(
mock_llm, mock_llm_registry
):
"""Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events.""" """Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
max_size = 10 max_size = 10
keep_first = 3 keep_first = 3
condenser = LLMSummarizingCondenser( condenser = LLMSummarizingCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
) )
mock_llm.set_mock_response_content('Summary of forgotten events') mock_llm.set_mock_response_content('Summary of forgotten events')
@@ -412,14 +447,14 @@ def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
assert isinstance(view[keep_first], AgentCondensationObservation) assert isinstance(view[keep_first], AgentCondensationObservation)
def test_amortized_forgetting_condenser_from_config(): def test_amortized_forgetting_condenser_from_config(mock_llm_registry):
"""Test that AmortizedForgettingCondenser objects can be made from config.""" """Test that AmortizedForgettingCondenser objects can be made from config."""
max_size = 50 max_size = 50
keep_first = 10 keep_first = 10
config = AmortizedForgettingCondenserConfig( config = AmortizedForgettingCondenserConfig(
max_size=max_size, keep_first=keep_first max_size=max_size, keep_first=keep_first
) )
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, AmortizedForgettingCondenser) assert isinstance(condenser, AmortizedForgettingCondenser)
assert condenser.max_size == max_size assert condenser.max_size == max_size
@@ -475,7 +510,7 @@ def test_amortized_forgetting_condenser_keeps_first_and_last_events():
assert view[:keep_first] == events[: min(keep_first, i + 1)] assert view[:keep_first] == events[: min(keep_first, i + 1)]
def test_llm_attention_condenser_from_config(): def test_llm_attention_condenser_from_config(mock_llm_registry):
"""Test that LLMAttentionCondenser objects can be made from config.""" """Test that LLMAttentionCondenser objects can be made from config."""
config = LLMAttentionCondenserConfig( config = LLMAttentionCondenserConfig(
max_size=50, max_size=50,
@@ -486,37 +521,32 @@ def test_llm_attention_condenser_from_config():
caching_prompt=True, caching_prompt=True,
), ),
) )
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, LLMAttentionCondenser) assert isinstance(condenser, LLMAttentionCondenser)
assert condenser.llm.config.model == 'gpt-4o' assert condenser.llm.config.model == 'gpt-4o'
assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
assert condenser.max_size == 50 assert condenser.max_size == 50
assert condenser.keep_first == 10 assert condenser.keep_first == 10
# Since this condenser can't take advantage of caching, we intercept the # Create a mock LLM that doesn't support function calling
# passed config and manually flip the caching prompt to False. mock_llm = MagicMock()
assert not condenser.llm.config.caching_prompt mock_llm.is_function_calling_active.return_value = False
# Create a new registry that returns our mock LLM that doesn't support function calling
mock_registry = MagicMock(spec=LLMRegistry)
mock_registry.get_llm.return_value = mock_llm
pytest.raises(ValueError, LLMAttentionCondenser.from_config, config, mock_registry)
def test_llm_attention_condenser_invalid_config(): def test_llm_attention_condenser_gives_expected_view_size(mock_llm, mock_llm_registry):
"""Test that LLMAttentionCondenser raises an error if the configured LLM doesn't support response schema."""
config = LLMAttentionCondenserConfig(
max_size=50,
keep_first=10,
llm_config=LLMConfig(
model='claude-2', # Older model that doesn't support response schema
api_key='test_key',
),
)
pytest.raises(ValueError, LLMAttentionCondenser.from_config, config)
def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
"""Test that the LLMAttentionCondenser gives views of the expected size.""" """Test that the LLMAttentionCondenser gives views of the expected size."""
max_size = 10 max_size = 10
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -534,10 +564,16 @@ def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size) assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_events_outside_history(mock_llm): def test_llm_attention_condenser_handles_events_outside_history(
mock_llm, mock_llm_registry
):
"""Test that the LLMAttentionCondenser handles event IDs that aren't from the event history.""" """Test that the LLMAttentionCondenser handles event IDs that aren't from the event history."""
max_size = 2 max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -555,10 +591,14 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
assert len(view) == harness.expected_size(i, max_size) assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_too_many_events(mock_llm): def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser handles when the response contains too many event IDs.""" """Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
max_size = 2 max_size = 2
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -576,12 +616,16 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm):
assert len(view) == harness.expected_size(i, max_size) assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_too_few_events(mock_llm): def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_llm_registry):
"""Test that the LLMAttentionCondenser handles when the response contains too few event IDs.""" """Test that the LLMAttentionCondenser handles when the response contains too few event IDs."""
max_size = 2 max_size = 2
# Developer note: We must specify keep_first=0 because # Developer note: We must specify keep_first=0 because
# keep_first (1) >= max_size//2 (1) is invalid. # keep_first (1) >= max_size//2 (1) is invalid.
condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) condenser = LLMAttentionCondenser(
max_size=max_size,
keep_first=0,
llm=mock_llm,
)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -597,12 +641,14 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm):
assert len(view) == harness.expected_size(i, max_size) assert len(view) == harness.expected_size(i, max_size)
def test_llm_attention_condenser_handles_keep_first_events(mock_llm): def test_llm_attention_condenser_handles_keep_first_events(mock_llm, mock_llm_registry):
"""Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size).""" """Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size)."""
max_size = 12 max_size = 12
keep_first = 4 keep_first = 4
condenser = LLMAttentionCondenser( condenser = LLMAttentionCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
) )
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -620,7 +666,7 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
assert view[:keep_first] == events[: min(keep_first, i + 1)] assert view[:keep_first] == events[: min(keep_first, i + 1)]
def test_structured_summary_condenser_from_config(): def test_structured_summary_condenser_from_config(mock_llm_registry):
"""Test that StructuredSummaryCondenser objects can be made from config.""" """Test that StructuredSummaryCondenser objects can be made from config."""
config = StructuredSummaryCondenserConfig( config = StructuredSummaryCondenserConfig(
max_size=50, max_size=50,
@@ -631,7 +677,7 @@ def test_structured_summary_condenser_from_config():
caching_prompt=True, caching_prompt=True,
), ),
) )
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, StructuredSummaryCondenser) assert isinstance(condenser, StructuredSummaryCondenser)
assert condenser.llm.config.model == 'gpt-4o' assert condenser.llm.config.model == 'gpt-4o'
@@ -639,40 +685,55 @@ def test_structured_summary_condenser_from_config():
assert condenser.max_size == 50 assert condenser.max_size == 50
assert condenser.keep_first == 10 assert condenser.keep_first == 10
# Since this condenser can't take advantage of caching, we intercept the
# passed config and manually flip the caching prompt to False.
assert not condenser.llm.config.caching_prompt
def test_structured_summary_condenser_invalid_config(mock_llm):
def test_structured_summary_condenser_invalid_config():
"""Test that StructuredSummaryCondenser raises error when keep_first > max_size.""" """Test that StructuredSummaryCondenser raises error when keep_first > max_size."""
# Since the condenser only works when function calling is on, we need to # Since the condenser only works when function calling is on, we need to
# mock up the check for that. # mock up the check for that.
llm = MagicMock() mock_llm.is_function_calling_active.return_value = True
llm.is_function_calling_active.return_value = True
pytest.raises( pytest.raises(
ValueError, ValueError,
StructuredSummaryCondenser, StructuredSummaryCondenser,
llm=llm, llm=mock_llm,
max_size=4, max_size=4,
keep_first=2, keep_first=2,
) )
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0) pytest.raises(
pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1) ValueError,
StructuredSummaryCondenser,
llm=mock_llm,
max_size=0,
)
pytest.raises(
ValueError,
StructuredSummaryCondenser,
llm=mock_llm,
keep_first=-1,
)
# If all other parameters are good but there's no function calling, the # If all other parameters are good but there's no function calling, the
# condenser still counts as improperly configured. # condenser still counts as improperly configured.
llm.is_function_calling_active.return_value = False # Create a mock LLM that doesn't support function calling
mock_llm_no_func = MagicMock()
mock_llm_no_func.is_function_calling_active.return_value = False
pytest.raises( pytest.raises(
ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2 ValueError,
StructuredSummaryCondenser,
llm=mock_llm_no_func,
max_size=40,
keep_first=2,
) )
def test_structured_summary_condenser_gives_expected_view_size(mock_llm): def test_structured_summary_condenser_gives_expected_view_size(
mock_llm, mock_llm_registry
):
"""Test that StructuredSummaryCondenser maintains the correct view size.""" """Test that StructuredSummaryCondenser maintains the correct view size."""
max_size = 10 max_size = 10
mock_llm.is_function_calling_active.return_value = True
condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm) condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm)
events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)] events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -686,12 +747,17 @@ def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
assert len(view) == harness.expected_size(i, max_size) assert len(view) == harness.expected_size(i, max_size)
def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm): def test_structured_summary_condenser_keeps_first_and_summary_events(
mock_llm, mock_llm_registry
):
"""Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events.""" """Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events."""
max_size = 10 max_size = 10
keep_first = 3 keep_first = 3
mock_llm.is_function_calling_active.return_value = True
condenser = StructuredSummaryCondenser( condenser = StructuredSummaryCondenser(
max_size=max_size, keep_first=keep_first, llm=mock_llm max_size=max_size,
keep_first=keep_first,
llm=mock_llm,
) )
mock_llm.set_mock_response_content('Summary of forgotten events') mock_llm.set_mock_response_content('Summary of forgotten events')
@@ -715,7 +781,7 @@ def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
assert isinstance(view[keep_first], AgentCondensationObservation) assert isinstance(view[keep_first], AgentCondensationObservation)
def test_condenser_pipeline_from_config(): def test_condenser_pipeline_from_config(mock_llm_registry):
"""Test that CondenserPipeline condensers can be created from configuration objects.""" """Test that CondenserPipeline condensers can be created from configuration objects."""
config = CondenserPipelineConfig( config = CondenserPipelineConfig(
condensers=[ condensers=[
@@ -728,7 +794,7 @@ def test_condenser_pipeline_from_config():
), ),
] ]
) )
condenser = Condenser.from_config(config) condenser = Condenser.from_config(config, mock_llm_registry)
assert isinstance(condenser, CondenserPipeline) assert isinstance(condenser, CondenserPipeline)
assert len(condenser.condensers) == 3 assert len(condenser.condensers) == 3
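The tests above repeatedly exercise the new two-argument Condenser.from_config(config, llm_registry) signature. A minimal sketch of that construction path, assuming only the import paths visible in this diff and using hypothetical values (illustrative only, not part of the commit):

    # Illustrative sketch only; assumes the import paths shown in this diff.
    from openhands.core.config.condenser_config import LLMSummarizingCondenserConfig
    from openhands.core.config.llm_config import LLMConfig
    from openhands.core.config.openhands_config import OpenHandsConfig
    from openhands.llm.llm_registry import LLMRegistry
    from openhands.memory.condenser import Condenser

    # The registry is built from the application config and handed to the
    # condenser factory, which resolves the condenser's LLM through it.
    registry = LLMRegistry(config=OpenHandsConfig())
    condenser_config = LLMSummarizingCondenserConfig(
        max_size=50,
        keep_first=10,
        llm_config=LLMConfig(model='gpt-4o', api_key='test_key'),  # hypothetical values
    )
    condenser = Condenser.from_config(condenser_config, registry)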

View File

@@ -0,0 +1,490 @@
import base64
import pickle
from unittest.mock import patch
import pytest
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
from openhands.llm.metrics import Metrics
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_file_store():
"""Create a mock file store for testing."""
return InMemoryFileStore({})
@pytest.fixture
def conversation_stats(mock_file_store):
"""Create a ConversationStats instance for testing."""
return ConversationStats(
file_store=mock_file_store,
conversation_id='test-conversation-id',
user_id='test-user-id',
)
@pytest.fixture
def mock_llm_registry():
"""Create a mock LLM registry that properly simulates LLM registration."""
config = OpenHandsConfig()
registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
return registry
@pytest.fixture
def connected_registry_and_stats(mock_llm_registry, conversation_stats):
"""Connect the LLMRegistry and ConversationStats properly."""
# Subscribe to LLM registry events to track metrics
mock_llm_registry.subscribe(conversation_stats.register_llm)
return mock_llm_registry, conversation_stats
def test_conversation_stats_initialization(conversation_stats):
"""Test that ConversationStats initializes correctly."""
assert conversation_stats.conversation_id == 'test-conversation-id'
assert conversation_stats.user_id == 'test-user-id'
assert conversation_stats.service_to_metrics == {}
assert isinstance(conversation_stats.restored_metrics, dict)
def test_save_metrics(conversation_stats, mock_file_store):
"""Test that metrics are saved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics
# Save metrics
conversation_stats.save_metrics()
# Verify that metrics were saved to the file store
try:
# Verify the saved content can be decoded and unpickled
encoded = mock_file_store.read(conversation_stats.metrics_path)
pickled = base64.b64decode(encoded)
restored = pickle.loads(pickled)
assert service_id in restored
assert restored[service_id].accumulated_cost == 0.05
except FileNotFoundError:
pytest.fail(f'File not found: {conversation_stats.metrics_path}')
def test_maybe_restore_metrics(mock_file_store):
"""Test that metrics are restored correctly."""
# Create metrics to save
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.1)
service_to_metrics = {service_id: metrics}
# Serialize and save metrics
pickled = pickle.dumps(service_to_metrics)
serialized_metrics = base64.b64encode(pickled).decode('utf-8')
# Create a new ConversationStats with pre-populated file store
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'
# Get the correct path using the same function as ConversationStats
from openhands.storage.locations import get_conversation_stats_filename
metrics_path = get_conversation_stats_filename(conversation_id, user_id)
# Write to the correct path
mock_file_store.write(metrics_path, serialized_metrics)
# Create ConversationStats which should restore metrics
stats = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)
# Verify metrics were restored
assert service_id in stats.restored_metrics
assert stats.restored_metrics[service_id].accumulated_cost == 0.1
def test_get_combined_metrics(conversation_stats):
"""Test that combined metrics are calculated correctly."""
# Add multiple services with metrics
service1 = 'service1'
metrics1 = Metrics(model_name='gpt-4')
metrics1.add_cost(0.05)
metrics1.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
service2 = 'service2'
metrics2 = Metrics(model_name='gpt-3.5')
metrics2.add_cost(0.02)
metrics2.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)
conversation_stats.service_to_metrics[service1] = metrics1
conversation_stats.service_to_metrics[service2] = metrics2
# Get combined metrics
combined = conversation_stats.get_combined_metrics()
# Verify combined metrics
assert combined.accumulated_cost == 0.07 # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
) # max of 8000 and 4000
def test_get_metrics_for_service(conversation_stats):
"""Test that metrics for a specific service are retrieved correctly."""
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
conversation_stats.service_to_metrics[service_id] = metrics
# Get metrics for the service
retrieved_metrics = conversation_stats.get_metrics_for_service(service_id)
# Verify metrics
assert retrieved_metrics.accumulated_cost == 0.05
assert retrieved_metrics is metrics # Should be the same object
# Test getting metrics for non-existent service
# Use a specific exception message pattern instead of a blind Exception
with pytest.raises(Exception, match='LLM service does not exist'):
conversation_stats.get_metrics_for_service('non-existent-service')
def test_register_llm_with_new_service(conversation_stats):
"""Test registering a new LLM service."""
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id='new-service', config=llm_config)
# Create a registry event
service_id = 'new-service'
event = RegistryEvent(llm=llm, service_id=service_id)
# Register the LLM
conversation_stats.register_llm(event)
# Verify the service was registered
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
def test_register_llm_with_restored_metrics(conversation_stats):
"""Test registering an LLM service with restored metrics."""
# Create restored metrics
service_id = 'restored-service'
restored_metrics = Metrics(model_name='gpt-4')
restored_metrics.add_cost(0.1)
conversation_stats.restored_metrics = {service_id: restored_metrics}
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)
# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)
# Register the LLM
conversation_stats.register_llm(event)
# Verify the service was registered with restored metrics
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
assert llm.metrics.accumulated_cost == 0.1 # Restored cost
# Verify the specific service was removed from restored_metrics
assert service_id not in conversation_stats.restored_metrics
assert hasattr(
conversation_stats, 'restored_metrics'
) # The dict should still exist
def test_llm_registry_notifications(connected_registry_and_stats):
"""Test that LLM registry notifications update conversation stats."""
mock_llm_registry, conversation_stats = connected_registry_and_stats
# Create a new LLM through the registry
service_id = 'test-service'
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Get LLM from registry (this should trigger the notification)
llm = mock_llm_registry.get_llm(service_id, llm_config)
# Verify the service was registered in conversation stats
assert service_id in conversation_stats.service_to_metrics
assert conversation_stats.service_to_metrics[service_id] is llm.metrics
# Add some metrics to the LLM
llm.metrics.add_cost(0.05)
llm.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
# Verify the metrics are reflected in conversation stats
assert conversation_stats.service_to_metrics[service_id].accumulated_cost == 0.05
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.prompt_tokens
== 100
)
assert (
conversation_stats.service_to_metrics[
service_id
].accumulated_token_usage.completion_tokens
== 50
)
# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.05
assert combined.accumulated_token_usage.prompt_tokens == 100
assert combined.accumulated_token_usage.completion_tokens == 50
def test_multiple_llm_services(connected_registry_and_stats):
"""Test tracking metrics for multiple LLM services."""
mock_llm_registry, conversation_stats = connected_registry_and_stats
# Create multiple LLMs through the registry
service1 = 'service1'
service2 = 'service2'
llm_config1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
llm_config2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Get LLMs from registry (this should trigger notifications)
llm1 = mock_llm_registry.get_llm(service1, llm_config1)
llm2 = mock_llm_registry.get_llm(service2, llm_config2)
# Add different metrics to each LLM
llm1.metrics.add_cost(0.05)
llm1.metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
llm2.metrics.add_cost(0.02)
llm2.metrics.add_token_usage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4000,
response_id='resp2',
)
# Verify services were registered in conversation stats
assert service1 in conversation_stats.service_to_metrics
assert service2 in conversation_stats.service_to_metrics
# Verify individual metrics
assert conversation_stats.service_to_metrics[service1].accumulated_cost == 0.05
assert conversation_stats.service_to_metrics[service2].accumulated_cost == 0.02
# Get combined metrics and verify
combined = conversation_stats.get_combined_metrics()
assert combined.accumulated_cost == 0.07 # 0.05 + 0.02
assert combined.accumulated_token_usage.prompt_tokens == 300 # 100 + 200
assert combined.accumulated_token_usage.completion_tokens == 150 # 50 + 100
assert (
combined.accumulated_token_usage.context_window == 8000
) # max of 8000 and 4000
def test_register_llm_with_multiple_restored_services_bug(conversation_stats):
"""Test that reproduces the bug where del self.restored_metrics deletes entire dict instead of specific service."""
# Create restored metrics for multiple services
service_id_1 = 'service-1'
service_id_2 = 'service-2'
restored_metrics_1 = Metrics(model_name='gpt-4')
restored_metrics_1.add_cost(0.1)
restored_metrics_2 = Metrics(model_name='gpt-3.5')
restored_metrics_2.add_cost(0.05)
# Set up restored metrics for both services
conversation_stats.restored_metrics = {
service_id_1: restored_metrics_1,
service_id_2: restored_metrics_2,
}
# Create LLM configs
llm_config_1 = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
llm_config_2 = LLMConfig(
model='gpt-3.5-turbo',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
# Register first LLM
llm_1 = LLM(service_id=service_id_1, config=llm_config_1)
event_1 = RegistryEvent(llm=llm_1, service_id=service_id_1)
conversation_stats.register_llm(event_1)
# Verify first service was registered with restored metrics
assert service_id_1 in conversation_stats.service_to_metrics
assert llm_1.metrics.accumulated_cost == 0.1
# After registering first service, restored_metrics should still contain service_id_2
assert service_id_2 in conversation_stats.restored_metrics
# Register second LLM - this should also work with restored metrics
llm_2 = LLM(service_id=service_id_2, config=llm_config_2)
event_2 = RegistryEvent(llm=llm_2, service_id=service_id_2)
conversation_stats.register_llm(event_2)
# Verify second service was registered with restored metrics
assert service_id_2 in conversation_stats.service_to_metrics
assert llm_2.metrics.accumulated_cost == 0.05
# After both services are registered, restored_metrics should be empty
assert len(conversation_stats.restored_metrics) == 0
def test_save_and_restore_workflow(mock_file_store):
"""Test the full workflow of saving and restoring metrics."""
# Create initial conversation stats
conversation_id = 'test-conversation-id'
user_id = 'test-user-id'
stats1 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)
# Add a service with metrics
service_id = 'test-service'
metrics = Metrics(model_name='gpt-4')
metrics.add_cost(0.05)
metrics.add_token_usage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=8000,
response_id='resp1',
)
stats1.service_to_metrics[service_id] = metrics
# Save metrics
stats1.save_metrics()
# Create a new conversation stats instance that should restore the metrics
stats2 = ConversationStats(
file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
)
# Verify metrics were restored
assert service_id in stats2.restored_metrics
assert stats2.restored_metrics[service_id].accumulated_cost == 0.05
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.prompt_tokens == 100
)
assert (
stats2.restored_metrics[service_id].accumulated_token_usage.completion_tokens
== 50
)
# Create a real LLM instance with a mock config
llm_config = LLMConfig(
model='gpt-4o',
api_key='test_key',
num_retries=2,
retry_min_wait=1,
retry_max_wait=2,
)
# Patch the LLM class to avoid actual API calls
with patch('openhands.llm.llm.litellm_completion'):
llm = LLM(service_id=service_id, config=llm_config)
# Create a registry event
event = RegistryEvent(llm=llm, service_id=service_id)
# Register the LLM to trigger restoration
stats2.register_llm(event)
# Verify metrics were applied to the LLM
assert llm.metrics.accumulated_cost == 0.05
assert llm.metrics.accumulated_token_usage.prompt_tokens == 100
assert llm.metrics.accumulated_token_usage.completion_tokens == 50
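Taken together, the tests above describe a small wiring pattern between the registry and the stats service. A compact sketch of that pattern, using only calls exercised in this file and hypothetical identifiers (not part of the commit):

    # Illustrative sketch only; every call mirrors one used in the tests above.
    from openhands.core.config import LLMConfig, OpenHandsConfig
    from openhands.llm.llm_registry import LLMRegistry
    from openhands.server.services.conversation_stats import ConversationStats
    from openhands.storage.memory import InMemoryFileStore

    file_store = InMemoryFileStore({})
    stats = ConversationStats(
        file_store=file_store,
        conversation_id='c1',  # hypothetical ids
        user_id='u1',
    )
    registry = LLMRegistry(config=OpenHandsConfig())

    # ConversationStats listens for LLM creation events so each service's
    # metrics object is tracked and any restored metrics are re-applied.
    registry.subscribe(stats.register_llm)

    llm = registry.get_llm('agent', LLMConfig(model='gpt-4o', api_key='test_key'))  # 'agent' is hypothetical
    llm.metrics.add_cost(0.05)

    stats.save_metrics()                     # persists per-service metrics to the file store
    combined = stats.get_combined_metrics()  # aggregates cost and token usage across services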

View File

@@ -1,6 +1,6 @@
"""Tests for the conversation summary generator.""" """Tests for the conversation summary generator."""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock
import pytest import pytest
@@ -11,54 +11,50 @@ from openhands.utils.conversation_summary import generate_conversation_title
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_conversation_title_empty_message(): async def test_generate_conversation_title_empty_message():
"""Test that an empty message returns None.""" """Test that an empty message returns None."""
result = await generate_conversation_title('', MagicMock()) mock_llm_registry = MagicMock()
mock_llm_config = LLMConfig(model='test-model')
result = await generate_conversation_title('', mock_llm_config, mock_llm_registry)
assert result is None assert result is None
result = await generate_conversation_title(' ', MagicMock()) result = await generate_conversation_title(
' ', mock_llm_config, mock_llm_registry
)
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_conversation_title_success(): async def test_generate_conversation_title_success():
"""Test successful title generation.""" """Test successful title generation."""
# Create a proper mock response # Create a mock LLM registry that returns a title
mock_response = MagicMock() mock_llm_registry = MagicMock()
mock_response.choices = [MagicMock()] mock_llm_registry.request_extraneous_completion.return_value = 'Generated Title'
mock_response.choices[0].message.content = 'Generated Title'
# Create a mock LLM instance with a synchronous completion method mock_llm_config = LLMConfig(model='test-model')
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title( result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model') 'Can you help me with Python?', mock_llm_config, mock_llm_registry
) )
assert result == 'Generated Title' assert result == 'Generated Title'
# Verify the mock was called with the expected arguments # Verify the mock was called with the expected arguments
mock_llm.completion.assert_called_once() mock_llm_registry.request_extraneous_completion.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_conversation_title_long_title(): async def test_generate_conversation_title_long_title():
"""Test that long titles are truncated.""" """Test that long titles are truncated."""
# Create a proper mock response with a long title # Create a mock LLM registry that returns a long title
mock_response = MagicMock() mock_llm_registry = MagicMock()
mock_response.choices = [MagicMock()] mock_llm_registry.request_extraneous_completion.return_value = 'This is a very long title that should be truncated because it exceeds the maximum length'
mock_response.choices[
0
].message.content = 'This is a very long title that should be truncated because it exceeds the maximum length'
# Create a mock LLM instance with a synchronous completion method mock_llm_config = LLMConfig(model='test-model')
mock_llm = MagicMock()
mock_llm.completion = MagicMock(return_value=mock_response)
# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title( result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model'), max_length=30 'Can you help me with Python?',
mock_llm_config,
mock_llm_registry,
max_length=30,
) )
# Verify the title is truncated correctly # Verify the title is truncated correctly
@@ -69,14 +65,16 @@ async def test_generate_conversation_title_long_title():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_conversation_title_exception(): async def test_generate_conversation_title_exception():
"""Test that exceptions are handled gracefully.""" """Test that exceptions are handled gracefully."""
# Create a mock LLM instance with a synchronous completion method that raises an exception # Create a mock LLM registry that raises an exception
mock_llm = MagicMock() mock_llm_registry = MagicMock()
mock_llm.completion = MagicMock(side_effect=Exception('Test error')) mock_llm_registry.request_extraneous_completion.side_effect = Exception(
'Test error'
)
mock_llm_config = LLMConfig(model='test-model')
# Patch the LLM class to return our mock
with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
result = await generate_conversation_title( result = await generate_conversation_title(
'Can you help me with Python?', LLMConfig(model='test-model') 'Can you help me with Python?', mock_llm_config, mock_llm_registry
) )
# Verify that None is returned when an exception occurs # Verify that None is returned when an exception occurs
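The helper now takes the LLM config and the registry explicitly instead of constructing an LLM itself; a minimal usage sketch under the assumptions shown in these tests (illustrative only, not part of the commit):

    # Illustrative sketch; mirrors the new generate_conversation_title signature tested above.
    from openhands.core.config import LLMConfig
    from openhands.llm.llm_registry import LLMRegistry
    from openhands.utils.conversation_summary import generate_conversation_title

    async def title_for(first_user_message: str, registry: LLMRegistry) -> str | None:
        # Returns None for blank input or on errors; long titles are truncated.
        return await generate_conversation_title(
            first_user_message,
            LLMConfig(model='test-model'),  # hypothetical model name
            registry,
            max_length=30,
        )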

View File

@@ -4,6 +4,7 @@ import pytest
from openhands.core.config import OpenHandsConfig from openhands.core.config import OpenHandsConfig
from openhands.events import EventStream from openhands.events import EventStream
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.impl.docker.docker_runtime import DockerRuntime from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
@@ -40,12 +41,17 @@ def event_stream():
return MagicMock(spec=EventStream) return MagicMock(spec=EventStream)
@pytest.fixture
def llm_registry():
return MagicMock(spec=LLMRegistry)
@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers') @patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_stopped_when_keep_runtime_alive_false( def test_container_stopped_when_keep_runtime_alive_false(
mock_stop_containers, mock_docker_client, config, event_stream mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
): ):
# Arrange # Arrange
runtime = DockerRuntime(config, event_stream, sid='test-sid') runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value runtime.container = mock_docker_client.containers.get.return_value
# Act # Act
@@ -57,11 +63,11 @@ def test_container_stopped_when_keep_runtime_alive_false(
@patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers') @patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
def test_container_not_stopped_when_keep_runtime_alive_true( def test_container_not_stopped_when_keep_runtime_alive_true(
mock_stop_containers, mock_docker_client, config, event_stream mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
): ):
# Arrange # Arrange
config.sandbox.keep_runtime_alive = True config.sandbox.keep_runtime_alive = True
runtime = DockerRuntime(config, event_stream, sid='test-sid') runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
runtime.container = mock_docker_client.containers.get.return_value runtime.container = mock_docker_client.containers.get.return_value
# Act # Act

View File

@@ -0,0 +1,178 @@
from __future__ import annotations
import unittest
from unittest.mock import MagicMock, patch
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
class TestLLMRegistry(unittest.TestCase):
def setUp(self):
"""Set up test environment before each test."""
# Create a basic LLM config for testing
self.llm_config = LLMConfig(model='test-model')
# Create a basic OpenHands config for testing
self.config = OpenHandsConfig(
llms={'llm': self.llm_config}, default_agent='CodeActAgent'
)
# Create a registry for testing
self.registry = LLMRegistry(config=self.config)
def test_get_llm_creates_new_llm(self):
"""Test that get_llm creates a new LLM when service doesn't exist."""
service_id = 'test-service'
# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
mock_llm.config = self.llm_config
mock_create.return_value = mock_llm
# Get LLM for the first time
llm = self.registry.get_llm(service_id, self.llm_config)
# Verify LLM was created and stored
self.assertEqual(llm, mock_llm)
mock_create.assert_called_once_with(
config=self.llm_config, service_id=service_id
)
def test_get_llm_returns_existing_llm(self):
"""Test that get_llm returns existing LLM when service already exists."""
service_id = 'test-service'
# Mock the _create_new_llm method to avoid actual LLM initialization
with patch.object(self.registry, '_create_new_llm') as mock_create:
mock_llm = MagicMock()
            mock_llm.config = self.llm_config
            mock_create.return_value = mock_llm

            # Get LLM for the first time
            llm1 = self.registry.get_llm(service_id, self.llm_config)

            # Manually add to registry to simulate existing LLM
            self.registry.service_to_llm[service_id] = mock_llm

            # Get LLM for the second time - should return the same instance
            llm2 = self.registry.get_llm(service_id, self.llm_config)

            # Verify same LLM instance is returned
            self.assertEqual(llm1, llm2)
            self.assertEqual(llm1, mock_llm)

            # Verify _create_new_llm was only called once
            mock_create.assert_called_once()

    def test_get_llm_with_different_config_raises_error(self):
        """Test that requesting same service ID with different config raises an error."""
        service_id = 'test-service'
        different_config = LLMConfig(model='different-model')

        # Manually add an LLM to the registry to simulate existing service
        mock_llm = MagicMock()
        mock_llm.config = self.llm_config
        self.registry.service_to_llm[service_id] = mock_llm

        # Attempt to get LLM with different config should raise ValueError
        with self.assertRaises(ValueError) as context:
            self.registry.get_llm(service_id, different_config)

        self.assertIn('Requesting same service ID', str(context.exception))
        self.assertIn('with different config', str(context.exception))

    def test_get_llm_without_config_raises_error(self):
        """Test that requesting new LLM without config raises an error."""
        service_id = 'test-service'

        # Attempt to get LLM without providing config should raise ValueError
        with self.assertRaises(ValueError) as context:
            self.registry.get_llm(service_id, None)

        self.assertIn(
            'Requesting new LLM without specifying LLM config', str(context.exception)
        )

    def test_request_extraneous_completion(self):
        """Test that requesting an extraneous completion creates a new LLM if needed."""
        service_id = 'extraneous-service'
        messages = [{'role': 'user', 'content': 'Hello, world!'}]

        # Mock the _create_new_llm method to avoid actual LLM initialization
        with patch.object(self.registry, '_create_new_llm') as mock_create:
            mock_llm = MagicMock()
            mock_response = MagicMock()
            mock_response.choices = [MagicMock()]
            mock_response.choices[0].message.content = ' Hello from the LLM! '
            mock_llm.completion.return_value = mock_response
            mock_create.return_value = mock_llm

            # Mock the side effect to add the LLM to the registry
            def side_effect(*args, **kwargs):
                self.registry.service_to_llm[service_id] = mock_llm
                return mock_llm

            mock_create.side_effect = side_effect

            # Request a completion
            response = self.registry.request_extraneous_completion(
                service_id=service_id,
                llm_config=self.llm_config,
                messages=messages,
            )

            # Verify the response (should be stripped)
            self.assertEqual(response, 'Hello from the LLM!')

            # Verify that _create_new_llm was called with correct parameters
            mock_create.assert_called_once_with(
                config=self.llm_config, service_id=service_id, with_listener=False
            )

            # Verify completion was called with correct messages
            mock_llm.completion.assert_called_once_with(messages=messages)

    def test_get_active_llm(self):
        """Test that get_active_llm returns the active agent LLM."""
        active_llm = self.registry.get_active_llm()
        self.assertEqual(active_llm, self.registry.active_agent_llm)

    def test_subscribe_and_notify(self):
        """Test the subscription and notification system."""
        events_received = []

        def callback(event: RegistryEvent):
            events_received.append(event)

        # Subscribe to events
        self.registry.subscribe(callback)

        # Should receive notification for the active agent LLM
        self.assertEqual(len(events_received), 1)
        self.assertEqual(events_received[0].llm, self.registry.active_agent_llm)
        self.assertEqual(
            events_received[0].service_id, self.registry.active_agent_llm.service_id
        )

        # Test that the subscriber is set correctly
        self.assertIsNotNone(self.registry.subscriber)

        # Test notify method directly with a mock event
        with patch.object(self.registry, 'subscriber') as mock_subscriber:
            mock_event = MagicMock()
            self.registry.notify(mock_event)
            mock_subscriber.assert_called_once_with(mock_event)

    def test_registry_has_unique_id(self):
        """Test that each registry instance has a unique ID."""
        registry2 = LLMRegistry(config=self.config)
        self.assertNotEqual(self.registry.registry_id, registry2.registry_id)
        self.assertTrue(len(self.registry.registry_id) > 0)
        self.assertTrue(len(registry2.registry_id) > 0)


if __name__ == '__main__':
    unittest.main()
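
# A minimal usage sketch of the registry API exercised by the tests above. The
# service IDs and model name are illustrative only (not taken from the test suite);
# constructing the LLM objects is local, but the request_extraneous_completion call
# would issue a real completion request.
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

registry = LLMRegistry(config=OpenHandsConfig())

# Subscribers receive a RegistryEvent for the active agent LLM and for any LLM
# created afterwards.
registry.subscribe(lambda event: print(event.service_id, event.llm))

# The first request for a service ID creates and caches an LLM; repeating the
# request with an equal config returns the cached instance. A different config for
# the same service ID, or no config for a new service ID, raises ValueError.
draft_config = LLMConfig(model='gpt-4o-mini', api_key='fake-key')
draft_llm = registry.get_llm('draft-service', draft_config)
assert registry.get_llm('draft-service', draft_config) is draft_llm

# One-off completions can be routed through the registry; the returned text is
# stripped of surrounding whitespace.
reply = registry.request_extraneous_completion(
    service_id='summarizer',
    llm_config=draft_config,
    messages=[{'role': 'user', 'content': 'Hello, world!'}],
)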

View File

@@ -12,6 +12,8 @@ from openhands.core.config.mcp_config import (
MCPSSEServerConfig, MCPSSEServerConfig,
MCPStdioServerConfig, MCPStdioServerConfig,
) )
from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.conversation_init_data import ConversationInitData from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import Session from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore from openhands.storage.memory import InMemoryFileStore
@@ -428,6 +430,8 @@ async def test_session_preserves_env_mcp_config(monkeypatch):
file_store=InMemoryFileStore({}), file_store=InMemoryFileStore({}),
config=config, config=config,
sio=AsyncMock(), sio=AsyncMock(),
llm_registry=LLMRegistry(config=OpenHandsConfig()),
convo_stats=ConversationStats(None, 'test-sid', None),
) )
# Create empty settings # Create empty settings

View File

@@ -8,7 +8,8 @@ import pytest
from mcp import McpError from mcp import McpError
from openhands.controller.agent import Agent from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController, AgentState from openhands.controller.agent_controller import AgentController
from openhands.core.schema import AgentState
from openhands.events.action.mcp import MCPAction from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction from openhands.events.action.message import SystemMessageAction
from openhands.events.event import EventSource from openhands.events.event import EventSource
@@ -17,6 +18,8 @@ from openhands.events.stream import EventStream
from openhands.mcp.client import MCPClient from openhands.mcp.client import MCPClient
from openhands.mcp.tool import MCPClientTool from openhands.mcp.tool import MCPClientTool
from openhands.mcp.utils import call_tool_mcp from openhands.mcp.utils import call_tool_mcp
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
class MockConfig: class MockConfig:
@@ -34,6 +37,11 @@ class MockLLM:
self.config = MockConfig() self.config = MockConfig()
@pytest.fixture
def convo_stats():
return ConversationStats(None, 'convo-id', None)
class MockAgent(Agent): class MockAgent(Agent):
"""Mock agent for testing.""" """Mock agent for testing."""
@@ -53,7 +61,7 @@ class MockAgent(Agent):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mcp_tool_timeout_error_handling(): async def test_mcp_tool_timeout_error_handling(convo_stats):
"""Test that verifies MCP tool timeout errors are properly handled and returned as observations.""" """Test that verifies MCP tool timeout errors are properly handled and returned as observations."""
# Create a mock MCPClient # Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient) mock_client = mock.MagicMock(spec=MCPClient)
@@ -80,7 +88,7 @@ async def test_mcp_tool_timeout_error_handling():
mock_client.tool_map = {'test_tool': mock_tool} mock_client.tool_map = {'test_tool': mock_tool}
# Create a mock file store # Create a mock file store
mock_file_store = mock.MagicMock() mock_file_store = InMemoryFileStore({})
# Create a mock event stream # Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store) event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@@ -90,13 +98,12 @@ async def test_mcp_tool_timeout_error_handling():
# Create a mock agent controller # Create a mock agent controller
controller = AgentController( controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent, agent=agent,
event_stream=event_stream, event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10, iteration_delta=10,
budget_per_task_delta=None, budget_per_task_delta=None,
sid='test-session',
) )
# Set up the agent state # Set up the agent state
@@ -143,7 +150,7 @@ async def test_mcp_tool_timeout_error_handling():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mcp_tool_timeout_agent_continuation(): async def test_mcp_tool_timeout_agent_continuation(convo_stats):
"""Test that verifies the agent can continue processing after an MCP tool timeout.""" """Test that verifies the agent can continue processing after an MCP tool timeout."""
# Create a mock MCPClient # Create a mock MCPClient
mock_client = mock.MagicMock(spec=MCPClient) mock_client = mock.MagicMock(spec=MCPClient)
@@ -170,7 +177,7 @@ async def test_mcp_tool_timeout_agent_continuation():
mock_client.tool_map = {'test_tool': mock_tool} mock_client.tool_map = {'test_tool': mock_tool}
# Create a mock file store # Create a mock file store
mock_file_store = mock.MagicMock() mock_file_store = InMemoryFileStore({})
# Create a mock event stream # Create a mock event stream
event_stream = EventStream(sid='test-session', file_store=mock_file_store) event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@@ -180,13 +187,12 @@ async def test_mcp_tool_timeout_agent_continuation():
# Create a mock agent controller # Create a mock agent controller
controller = AgentController( controller = AgentController(
sid='test-session',
file_store=mock_file_store,
user_id='test-user',
agent=agent, agent=agent,
event_stream=event_stream, event_stream=event_stream,
convo_stats=convo_stats,
iteration_delta=10, iteration_delta=10,
budget_per_task_delta=None, budget_per_task_delta=None,
sid='test-session',
) )
# Set up the agent state # Set up the agent state
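
# For readability, the AgentController construction after this change, in unified
# form (a sketch; `agent`, `event_stream`, and `convo_stats` are the objects set up
# in the test body above). The controller no longer receives `file_store` or
# `user_id` directly; conversation accounting flows through `convo_stats` instead.
controller = AgentController(
    agent=agent,
    event_stream=event_stream,
    convo_stats=convo_stats,
    iteration_delta=10,
    budget_per_task_delta=None,
    sid='test-session',
)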

View File

@@ -21,11 +21,13 @@ from openhands.events.observation.agent import (
from openhands.events.serialization.observation import observation_from_dict from openhands.events.serialization.observation import observation_from_dict
from openhands.events.stream import EventStream from openhands.events.stream import EventStream
from openhands.llm import LLM from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry
from openhands.llm.metrics import Metrics from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import ( from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient, ActionExecutionClient,
) )
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import ( from openhands.utils.prompt import (
@@ -42,6 +44,12 @@ def file_store():
return InMemoryFileStore({}) return InMemoryFileStore({})
@pytest.fixture
def mock_llm_registry(file_store):
"""Create a mock LLMRegistry for testing."""
return MagicMock(spec=LLMRegistry)
@pytest.fixture @pytest.fixture
def event_stream(file_store): def event_stream(file_store):
"""Create a test event stream.""" """Create a test event stream."""
@@ -90,24 +98,29 @@ def mock_agent():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent): async def test_memory_on_event_exception_handling(
memory, event_stream, mock_agent, mock_llm_registry
):
"""Test that exceptions in Memory.on_event are properly handled via status callback.""" """Test that exceptions in Memory.on_event are properly handled via status callback."""
# Create a mock runtime # Create a mock runtime
runtime = MagicMock(spec=ActionExecutionClient) runtime = MagicMock(spec=ActionExecutionClient)
runtime.event_stream = event_stream runtime.event_stream = event_stream
# Mock Memory method to raise an exception # Mock Memory method to raise an exception
with patch.object( with (
patch.object(
memory, '_on_workspace_context_recall', side_effect=Exception('Test error') memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
): ):
state = await run_controller( state = await run_controller(
config=OpenHandsConfig(), config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'), initial_user_action=MessageAction(content='Test message'),
runtime=runtime, runtime=runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=memory, memory=memory,
llm_registry=mock_llm_registry,
) )
# Verify that the controller's last error was set # Verify that the controller's last error was set
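
# For readability, the updated run_controller call from the hunk above in unified
# form (a sketch). The agent is no longer passed in directly: these tests patch
# openhands.core.main.create_agent to return the mock agent, and the mocked
# registry is forwarded via the new llm_registry argument.
with (
    patch.object(
        memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
    ),
    patch('openhands.core.main.create_agent', return_value=mock_agent),
):
    state = await run_controller(
        config=OpenHandsConfig(),
        initial_user_action=MessageAction(content='Test message'),
        runtime=runtime,
        sid='test',
        fake_user_response_fn=lambda _: 'repeat',
        memory=memory,
        llm_registry=mock_llm_registry,
    )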
@@ -118,7 +131,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_memory_on_workspace_context_recall_exception_handling( async def test_memory_on_workspace_context_recall_exception_handling(
memory, event_stream, mock_agent memory, event_stream, mock_agent, mock_llm_registry
): ):
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback.""" """Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a mock runtime # Create a mock runtime
@@ -126,19 +139,22 @@ async def test_memory_on_workspace_context_recall_exception_handling(
runtime.event_stream = event_stream runtime.event_stream = event_stream
# Mock Memory._on_workspace_context_recall to raise an exception # Mock Memory._on_workspace_context_recall to raise an exception
with patch.object( with (
patch.object(
memory, memory,
'_find_microagent_knowledge', '_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'), side_effect=Exception('Test error from _find_microagent_knowledge'),
),
patch('openhands.core.main.create_agent', return_value=mock_agent),
): ):
state = await run_controller( state = await run_controller(
config=OpenHandsConfig(), config=OpenHandsConfig(),
initial_user_action=MessageAction(content='Test message'), initial_user_action=MessageAction(content='Test message'),
runtime=runtime, runtime=runtime,
sid='test', sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat', fake_user_response_fn=lambda _: 'repeat',
memory=memory, memory=memory,
llm_registry=mock_llm_registry,
) )
# Verify that the controller's last error was set # Verify that the controller's last error was set
@@ -593,12 +609,14 @@ REPOSITORY INSTRUCTIONS: This is the second test repository.
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conversation_instructions_plumbed_to_memory( async def test_conversation_instructions_plumbed_to_memory(
mock_agent, event_stream, file_store mock_agent, event_stream, file_store, mock_llm_registry
): ):
# Setup # Setup
session = AgentSession( session = AgentSession(
sid='test-session', sid='test-session',
file_store=file_store, file_store=file_store,
llm_registry=mock_llm_registry,
convo_stats=ConversationStats(file_store, 'test-session', None),
) )
# Create a mock runtime and set it up # Create a mock runtime and set it up

View File

@@ -3,26 +3,30 @@ from litellm import ModelResponse
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.events.action import MessageAction from openhands.events.action import MessageAction
from openhands.llm.llm import LLM from openhands.llm.llm_registry import LLMRegistry
@pytest.fixture @pytest.fixture
def mock_llm(): def llm_config():
llm = LLM( return LLMConfig(
LLMConfig(
model='claude-3-5-sonnet-20241022', model='claude-3-5-sonnet-20241022',
api_key='fake', api_key='fake',
caching_prompt=True, caching_prompt=True,
) )
)
return llm
@pytest.fixture @pytest.fixture
def codeact_agent(mock_llm): def llm_registry():
registry = LLMRegistry(config=OpenHandsConfig())
return registry
@pytest.fixture
def codeact_agent(llm_registry):
config = AgentConfig() config = AgentConfig()
agent = CodeActAgent(mock_llm, config) agent = CodeActAgent(config, llm_registry)
return agent return agent

View File

@@ -12,14 +12,26 @@ from openhands.events.observation import NullObservation, Observation
from openhands.events.stream import EventStream from openhands.events.stream import EventStream
from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError, Repository from openhands.integrations.service_types import AuthenticationError, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.base import Runtime from openhands.runtime.base import Runtime
from openhands.storage import get_file_store from openhands.storage import get_file_store
class TestRuntime(Runtime): class MockRuntime(Runtime):
"""A concrete implementation of Runtime for testing""" """A concrete implementation of Runtime for testing"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# Ensure llm_registry is provided if not already in kwargs
if 'llm_registry' not in kwargs and len(args) < 3:
# Create a mock LLMRegistry if not provided
config = (
kwargs.get('config')
if 'config' in kwargs
else args[0]
if args
else OpenHandsConfig()
)
kwargs['llm_registry'] = LLMRegistry(config=config)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.run_action_calls = [] self.run_action_calls = []
self._execute_shell_fn_git_handler = MagicMock( self._execute_shell_fn_git_handler = MagicMock(
@@ -89,9 +101,11 @@ def runtime(temp_dir):
) )
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( llm_registry = LLMRegistry(config=config)
runtime = MockRuntime(
config=config, config=config,
event_stream=event_stream, event_stream=event_stream,
llm_registry=llm_registry,
sid='test', sid='test',
user_id='test_user', user_id='test_user',
git_provider_tokens=git_provider_tokens, git_provider_tokens=git_provider_tokens,
@@ -119,7 +133,7 @@ async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
config = OpenHandsConfig() config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime(config=config, event_stream=event_stream, sid='test') runtime = MockRuntime(config=config, event_stream=event_stream, sid='test')
# Create a command that would normally trigger token export # Create a command that would normally trigger token export
cmd = CmdRunAction(command='echo $GITHUB_TOKEN') cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
@@ -137,7 +151,7 @@ async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir):
config = OpenHandsConfig() config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user' config=config, event_stream=event_stream, sid='test', user_id='test_user'
) )
@@ -177,7 +191,7 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
) )
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, config=config,
event_stream=event_stream, event_stream=event_stream,
sid='test', sid='test',
@@ -225,7 +239,7 @@ async def test_clone_or_init_repo_no_repo_init_git_in_empty_workspace(temp_dir):
config.init_git_in_empty_workspace = True config.init_git_in_empty_workspace = True
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None config=config, event_stream=event_stream, sid='test', user_id=None
) )
@@ -249,7 +263,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_di
config.workspace_base = '/some/path' # Set workspace_base config.workspace_base = '/some/path' # Set workspace_base
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id=None config=config, event_stream=event_stream, sid='test', user_id=None
) )
@@ -267,7 +281,7 @@ async def test_clone_or_init_repo_auth_error(temp_dir):
config = OpenHandsConfig() config = OpenHandsConfig()
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user' config=config, event_stream=event_stream, sid='test', user_id='test_user'
) )
@@ -298,7 +312,7 @@ async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
{ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))} {ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
) )
runtime = TestRuntime( runtime = MockRuntime(
config=config, config=config,
event_stream=event_stream, event_stream=event_stream,
sid='test', sid='test',
@@ -336,7 +350,7 @@ async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user' config=config, event_stream=event_stream, sid='test', user_id='test_user'
) )
@@ -371,7 +385,7 @@ async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
{ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))} {ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
) )
runtime = TestRuntime( runtime = MockRuntime(
config=config, config=config,
event_stream=event_stream, event_stream=event_stream,
sid='test', sid='test',
@@ -410,7 +424,7 @@ async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
file_store = get_file_store('local', temp_dir) file_store = get_file_store('local', temp_dir)
event_stream = EventStream('abc', file_store) event_stream = EventStream('abc', file_store)
runtime = TestRuntime( runtime = MockRuntime(
config=config, event_stream=event_stream, sid='test', user_id='test_user' config=config, event_stream=event_stream, sid='test', user_id='test_user'
) )

View File

@@ -9,10 +9,12 @@ import pytest
from openhands.core.config import OpenHandsConfig, SandboxConfig from openhands.core.config import OpenHandsConfig, SandboxConfig
from openhands.events import EventStream from openhands.events import EventStream
from openhands.integrations.service_types import ProviderType, Repository from openhands.integrations.service_types import ProviderType, Repository
from openhands.llm.llm_registry import LLMRegistry
from openhands.microagent.microagent import ( from openhands.microagent.microagent import (
RepoMicroagent, RepoMicroagent,
) )
from openhands.runtime.base import Runtime from openhands.runtime.base import Runtime
from openhands.storage import get_file_store
class MockRuntime(Runtime): class MockRuntime(Runtime):
@@ -24,12 +26,21 @@ class MockRuntime(Runtime):
config.workspace_mount_path_in_sandbox = str(workspace_root) config.workspace_mount_path_in_sandbox = str(workspace_root)
config.sandbox = SandboxConfig() config.sandbox = SandboxConfig()
# Create a mock event stream # Create a mock event stream and file store
file_store = get_file_store('local', str(workspace_root))
event_stream = MagicMock(spec=EventStream) event_stream = MagicMock(spec=EventStream)
event_stream.file_store = file_store
# Create a mock LLM registry
llm_registry = LLMRegistry(config)
# Initialize the parent class properly # Initialize the parent class properly
super().__init__( super().__init__(
config=config, event_stream=event_stream, sid='test', git_provider_tokens={} config=config,
event_stream=event_stream,
llm_registry=llm_registry,
sid='test',
git_provider_tokens={},
) )
self._workspace_root = workspace_root self._workspace_root = workspace_root

View File

@@ -595,7 +595,7 @@ async def test_check_usertask(
analyzer = InvariantAnalyzer(event_stream) analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_appropriate}}]} mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
mock_litellm_completion.return_value = mock_response mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config) analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True analyzer.check_browsing_alignment = True
data = [ data = [
(MessageAction(usertask), EventSource.USER), (MessageAction(usertask), EventSource.USER),
@@ -657,7 +657,7 @@ async def test_check_fillaction(
analyzer = InvariantAnalyzer(event_stream) analyzer = InvariantAnalyzer(event_stream)
mock_response = {'choices': [{'message': {'content': is_harmful}}]} mock_response = {'choices': [{'message': {'content': is_harmful}}]}
mock_litellm_completion.return_value = mock_response mock_litellm_completion.return_value = mock_response
analyzer.guardrail_llm = LLM(config=default_config) analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
analyzer.check_browsing_alignment = True analyzer.check_browsing_alignment = True
data = [ data = [
(BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT), (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
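
# Both hunks above reflect the same API change: LLM is now constructed with an
# explicit service identifier. Minimal sketch (default_config is the LLMConfig
# fixture used in this module; the service ID string is illustrative):
guardrail_llm = LLM(config=default_config, service_id='guardrail')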

View File

@@ -7,7 +7,9 @@ from litellm.exceptions import (
from openhands.core.config.llm_config import LLMConfig from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry
from openhands.runtime.runtime_status import RuntimeStatus from openhands.runtime.runtime_status import RuntimeStatus
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.session import Session from openhands.server.session.session import Session
from openhands.storage.memory import InMemoryFileStore from openhands.storage.memory import InMemoryFileStore
@@ -33,10 +35,28 @@ def default_llm_config():
) )
@pytest.fixture
def llm_registry():
config = OpenHandsConfig()
return LLMRegistry(config=config)
@pytest.fixture
def conversation_stats():
file_store = InMemoryFileStore({})
return ConversationStats(
file_store=file_store, conversation_id='test-conversation', user_id='test-user'
)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('openhands.llm.llm.litellm_completion') @patch('openhands.llm.llm.litellm_completion')
async def test_notify_on_llm_retry( async def test_notify_on_llm_retry(
mock_litellm_completion, mock_sio, default_llm_config mock_litellm_completion,
mock_sio,
default_llm_config,
llm_registry,
conversation_stats,
): ):
config = OpenHandsConfig() config = OpenHandsConfig()
config.set_llm_config(default_llm_config) config.set_llm_config(default_llm_config)
@@ -44,6 +64,8 @@ async def test_notify_on_llm_retry(
sid='..sid..', sid='..sid..',
file_store=InMemoryFileStore({}), file_store=InMemoryFileStore({}),
config=config, config=config,
llm_registry=llm_registry,
convo_stats=conversation_stats,
sio=mock_sio, sio=mock_sio,
user_id='..uid..', user_id='..uid..',
) )
@@ -56,7 +78,15 @@ async def test_notify_on_llm_retry(
), ),
{'choices': [{'message': {'content': 'Retry successful'}}]}, {'choices': [{'message': {'content': 'Retry successful'}}]},
] ]
llm = session._create_llm('..cls..')
# Set the retry listener on the registry
llm_registry.retry_listner = session._notify_on_llm_retry
# Create an LLM through the registry
llm = llm_registry.get_llm(
service_id='test_service',
config=default_llm_config,
)
llm.completion( llm.completion(
messages=[{'role': 'user', 'content': 'Hello!'}], messages=[{'role': 'user', 'content': 'Hello!'}],