Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-01-08 22:38:05 -05:00)

[Refactor]: Add LLMRegistry for llm services (#9589)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
@@ -18,7 +18,7 @@ from openhands.events.action import (
 from openhands.events.event import EventSource
 from openhands.events.observation import BrowserOutputObservation
 from openhands.events.observation.observation import Observation
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.plugins import (
     PluginRequirement,
 )
@@ -102,15 +102,15 @@ class BrowsingAgent(Agent):

     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ) -> None:
         """Initializes a new instance of the BrowsingAgent class.

         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
         # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
         action_subsets = ['chat', 'bid']

@@ -3,6 +3,8 @@ import sys
 from collections import deque
 from typing import TYPE_CHECKING

+from openhands.llm.llm_registry import LLMRegistry
+
 if TYPE_CHECKING:
     from litellm import ChatCompletionToolParam

@@ -32,7 +34,6 @@ from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.events.action import AgentFinishAction, MessageAction
 from openhands.events.event import Event
-from openhands.llm.llm import LLM
 from openhands.llm.llm_utils import check_tools
 from openhands.memory.condenser import Condenser
 from openhands.memory.condenser.condenser import Condensation, View
@@ -74,18 +75,13 @@ class CodeActAgent(Agent):
         JupyterRequirement(),
     ]

-    def __init__(
-        self,
-        llm: LLM,
-        config: AgentConfig,
-    ) -> None:
+    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
         """Initializes a new instance of the CodeActAgent class.

         Parameters:
-        - llm (LLM): The llm to be used by this agent
         - config (AgentConfig): The configuration for this agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         self.pending_actions: deque['Action'] = deque()
         self.reset()
         self.tools = self._get_tools()
@@ -93,7 +89,7 @@ class CodeActAgent(Agent):
         # Create a ConversationMemory instance
         self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)

-        self.condenser = Condenser.from_config(self.config.condenser)
+        self.condenser = Condenser.from_config(self.config.condenser, llm_registry)
         logger.debug(f'Using condenser: {type(self.condenser)}')

     @property

@@ -22,7 +22,7 @@ from openhands.events.observation import (
     Observation,
 )
 from openhands.events.serialization.event import event_to_dict
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry

 """
 FIXME: There are a few problems this surfaced
@@ -42,8 +42,12 @@ class DummyAgent(Agent):
     without making any LLM calls.
     """

-    def __init__(self, llm: LLM, config: AgentConfig):
-        super().__init__(llm, config)
+    def __init__(
+        self,
+        config: AgentConfig,
+        llm_registry: LLMRegistry,
+    ):
+        super().__init__(config, llm_registry)
         self.steps: list[ActionObs] = [
             {
                 'action': MessageAction('Time to get started!'),

@@ -4,7 +4,7 @@ import openhands.agenthub.loc_agent.function_calling as locagent_function_calling
 from openhands.agenthub.codeact_agent import CodeActAgent
 from openhands.core.config import AgentConfig
 from openhands.core.logger import openhands_logger as logger
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry

 if TYPE_CHECKING:
     from openhands.events.action import Action
@@ -16,8 +16,8 @@ class LocAgent(CodeActAgent):

     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ) -> None:
         """Initializes a new instance of the LocAgent class.

@@ -25,7 +25,8 @@ class LocAgent(CodeActAgent):
         - llm (LLM): The llm to be used by this agent
         - config (AgentConfig): The configuration for the agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)

         self.tools = locagent_function_calling.get_tools()
         logger.debug(

@@ -3,6 +3,8 @@
 import os
 from typing import TYPE_CHECKING

+from openhands.llm.llm_registry import LLMRegistry
+
 if TYPE_CHECKING:
     from litellm import ChatCompletionToolParam

@@ -15,7 +17,6 @@ from openhands.agenthub.readonly_agent import (
 )
 from openhands.core.config import AgentConfig
 from openhands.core.logger import openhands_logger as logger
-from openhands.llm.llm import LLM
 from openhands.utils.prompt import PromptManager


@@ -37,17 +38,16 @@ class ReadOnlyAgent(CodeActAgent):

     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ) -> None:
         """Initializes a new instance of the ReadOnlyAgent class.

         Parameters:
-        - llm (LLM): The llm to be used by this agent
         - config (AgentConfig): The configuration for this agent
         """
         # Initialize the CodeActAgent class; some of it is overridden with class methods
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)

         logger.debug(
             f'TOOLS loaded for ReadOnlyAgent: {", ".join([tool.get("function").get("name") for tool in self.tools])}'

@@ -16,7 +16,7 @@ from openhands.events.action import (
 from openhands.events.event import EventSource
 from openhands.events.observation import BrowserOutputObservation
 from openhands.events.observation.observation import Observation
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.plugins import (
     PluginRequirement,
 )
@@ -127,17 +127,13 @@ class VisualBrowsingAgent(Agent):
     sandbox_plugins: list[PluginRequirement] = []
     response_parser = BrowsingResponseParser()

-    def __init__(
-        self,
-        llm: LLM,
-        config: AgentConfig,
-    ) -> None:
+    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
         """Initializes a new instance of the VisualBrowsingAgent class.

         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm, config)
+        super().__init__(config, llm_registry)
         # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
         # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
         action_subsets = [

@@ -83,6 +83,7 @@ from openhands.microagent.microagent import BaseMicroagent
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
 from openhands.storage.settings.file_settings_store import FileSettingsStore
+from openhands.utils.utils import create_registry_and_convo_stats


 async def cleanup_session(
@@ -147,9 +148,16 @@ async def run_session(
         None, display_initialization_animation, 'Initializing...', is_loaded
     )

-    agent = create_agent(config)
+    llm_registry, convo_stats, config = create_registry_and_convo_stats(
+        config,
+        sid,
+        None,
+    )
+
+    agent = create_agent(config, llm_registry)
     runtime = create_runtime(
         config,
+        llm_registry,
         sid=sid,
         headless_mode=True,
         agent=agent,
@@ -161,7 +169,7 @@ async def run_session(

     runtime.subscribe_to_shell_stream(stream_to_console)

-    controller, initial_state = create_controller(agent, runtime, config)
+    controller, initial_state = create_controller(agent, runtime, config, convo_stats)

     event_stream = runtime.event_stream

@@ -3,6 +3,8 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING

+from openhands.llm.llm_registry import LLMRegistry
+
 if TYPE_CHECKING:
     from openhands.controller.state.state import State
     from openhands.events.action import Action
@@ -17,7 +19,6 @@ from openhands.core.exceptions import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.event import EventSource
-from openhands.llm.llm import LLM
 from openhands.runtime.plugins import PluginRequirement


@@ -38,10 +39,11 @@ class Agent(ABC):

     def __init__(
         self,
-        llm: LLM,
         config: AgentConfig,
+        llm_registry: LLMRegistry,
     ):
-        self.llm = llm
+        self.llm = llm_registry.get_llm_from_agent_config('agent', config)
+        self.llm_registry = llm_registry
         self.config = config
         self._complete = False
         self._prompt_manager: 'PromptManager' | None = None

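The hunk above is the heart of the refactor: the base Agent no longer accepts an LLM directly but resolves one from the registry under the 'agent' service ID, and keeps the registry for later use (delegates, condensers). A minimal sketch of a downstream subclass against the new constructor; MyAgent is hypothetical and abstract members such as the step method are omitted:

# --- Illustrative sketch, not part of the diff ---
from openhands.controller.agent import Agent
from openhands.core.config import AgentConfig
from openhands.llm.llm_registry import LLMRegistry


class MyAgent(Agent):  # hypothetical subclass for illustration
    def __init__(self, config: AgentConfig, llm_registry: LLMRegistry) -> None:
        # The base class resolves self.llm via
        # llm_registry.get_llm_from_agent_config('agent', config)
        # and stores the registry on self.llm_registry.
        super().__init__(config, llm_registry)
        # No LLM is constructed here; the registry owns the instance.
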
@@ -73,9 +73,9 @@ from openhands.events.observation import (
     Observation,
 )
 from openhands.events.serialization.event import truncate_content
-from openhands.llm.llm import LLM
 from openhands.llm.metrics import Metrics
 from openhands.runtime.runtime_status import RuntimeStatus
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.files import FileStore

 # note: RESUME is only available on web GUI
@@ -109,6 +109,7 @@ class AgentController:
         self,
         agent: Agent,
         event_stream: EventStream,
+        convo_stats: ConversationStats,
         iteration_delta: int,
         budget_per_task_delta: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
@@ -148,6 +149,7 @@ class AgentController:
         self.agent = agent
         self.headless_mode = headless_mode
         self.is_delegate = is_delegate
+        self.convo_stats = convo_stats

         # the event stream must be set before maybe subscribing to it
         self.event_stream = event_stream
@@ -163,6 +165,7 @@ class AgentController:
         # state from the previous session, state from a parent agent, or a fresh state
         self.set_initial_state(
             state=initial_state,
+            convo_stats=convo_stats,
             max_iterations=iteration_delta,
             max_budget_per_task=budget_per_task_delta,
             confirmation_mode=confirmation_mode,
@@ -477,11 +480,6 @@ class AgentController:
             log_level, str(observation_to_print), extra={'msg_type': 'OBSERVATION'}
         )

-        # TODO: these metrics come from the draft editor, and they get accumulated into controller's state metrics and the agent's llm metrics
-        # In the future, we should have a more principled way to sharing metrics across all LLM instances for a given conversation
-        if observation.llm_metrics is not None:
-            self.state_tracker.merge_metrics(observation.llm_metrics)
-
         # this happens for runnable actions and microagent actions
         if self._pending_action and self._pending_action.id == observation.cause:
             if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
@@ -657,14 +655,10 @@ class AgentController:
         """
         agent_cls: type[Agent] = Agent.get_cls(action.agent)
         agent_config = self.agent_configs.get(action.agent, self.agent.config)
-        llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
         # Make sure metrics are shared between parent and child for global accumulation
-        llm = LLM(
-            config=llm_config,
-            retry_listener=self.agent.llm.retry_listener,
-            metrics=self.state.metrics,
+        delegate_agent = agent_cls(
+            config=agent_config, llm_registry=self.agent.llm_registry
         )
-        delegate_agent = agent_cls(llm=llm, config=agent_config)

         # Take a snapshot of the current metrics before starting the delegate
         state = State(
@@ -683,7 +677,7 @@ class AgentController:
         )
         self.log(
             'debug',
-            f'start delegate, creating agent {delegate_agent.name} using LLM {llm}',
+            f'start delegate, creating agent {delegate_agent.name}',
         )

         # Create the delegate with is_delegate=True so it does NOT subscribe directly
@@ -693,6 +687,7 @@ class AgentController:
             user_id=self.user_id,
             agent=delegate_agent,
             event_stream=self.event_stream,
+            convo_stats=self.convo_stats,
             iteration_delta=self._initial_max_iterations,
             budget_per_task_delta=self._initial_max_budget_per_task,
             agent_to_llm_config=self.agent_to_llm_config,
@@ -795,13 +790,8 @@ class AgentController:
             extra={'msg_type': 'STEP'},
         )

-        # Ensure budget control flag is synchronized with the latest metrics.
-        # In the future, we should centralized the use of one LLM object per conversation.
-        # This will help us unify the cost for auto generating titles, running the condensor, etc.
-        # Before many microservices will touh the same llm cost field, we should sync with the budget flag for the controller
-        # and check that we haven't exceeded budget BEFORE executing an agent step.
+        # Synchronize spend across all llm services with the budget flag
         self.state_tracker.sync_budget_flag_with_metrics()

         if self._is_stuck():
             await self._react_to_exception(
                 AgentStuckInLoopError('Agent got stuck in a loop')
@@ -961,14 +951,15 @@ class AgentController:
     def set_initial_state(
         self,
         state: State | None,
+        convo_stats: ConversationStats,
         max_iterations: int,
         max_budget_per_task: float | None,
         confirmation_mode: bool = False,
     ):
         self.state_tracker.set_initial_state(
             self.id,
-            self.agent,
             state,
+            convo_stats,
             max_iterations,
             max_budget_per_task,
             confirmation_mode,
@@ -1009,37 +1000,20 @@ class AgentController:
             action: The action to attach metrics to
         """
         # Get metrics from agent LLM
-        agent_metrics = self.state.metrics
-
-        # Get metrics from condenser LLM if it exists
-        condenser_metrics: Metrics | None = None
-        if hasattr(self.agent, 'condenser') and hasattr(self.agent.condenser, 'llm'):
-            condenser_metrics = self.agent.condenser.llm.metrics
-
-        # Create a new minimal metrics object with just what the frontend needs
-        metrics = Metrics(model_name=agent_metrics.model_name)
-
-        # Set accumulated cost (sum of agent and condenser costs)
-        metrics.accumulated_cost = agent_metrics.accumulated_cost
-        if condenser_metrics:
-            metrics.accumulated_cost += condenser_metrics.accumulated_cost
+        metrics = self.convo_stats.get_combined_metrics()
+
+        # Create a clean copy with only the fields we want to keep
+        clean_metrics = Metrics()
+        clean_metrics.accumulated_cost = metrics.accumulated_cost
+        clean_metrics._accumulated_token_usage = copy.deepcopy(
+            metrics.accumulated_token_usage
+        )

         # Add max_budget_per_task to metrics
         if self.state.budget_flag:
-            metrics.max_budget_per_task = self.state.budget_flag.max_value
+            clean_metrics.max_budget_per_task = self.state.budget_flag.max_value

-        # Set accumulated token usage (sum of agent and condenser token usage)
-        # Use a deep copy to ensure we don't modify the original object
-        metrics._accumulated_token_usage = (
-            agent_metrics.accumulated_token_usage.model_copy(deep=True)
-        )
-        if condenser_metrics:
-            metrics._accumulated_token_usage = (
-                metrics._accumulated_token_usage
-                + condenser_metrics.accumulated_token_usage
-            )
-
-        action.llm_metrics = metrics
+        action.llm_metrics = clean_metrics

         # Log the metrics information for debugging
         # Get the latest usage directly from the agent's metrics

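Two consequences of the controller changes above: delegate agents now reuse the parent's registry instead of building a fresh LLM, and the frontend-facing metrics copy is assembled from the conversation-wide stats. A sketch of the latter, using only the fields visible in the hunk; build_clean_metrics is an invented helper name:

# --- Illustrative sketch, not part of the diff ---
import copy

from openhands.llm.metrics import Metrics


def build_clean_metrics(convo_stats, budget_flag) -> Metrics:
    combined = convo_stats.get_combined_metrics()  # spend across all LLM services
    clean = Metrics()
    clean.accumulated_cost = combined.accumulated_cost
    clean._accumulated_token_usage = copy.deepcopy(
        combined.accumulated_token_usage
    )
    if budget_flag:
        clean.max_budget_per_task = budget_flag.max_value
    return clean  # attached to the action as action.llm_metrics
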
@@ -21,6 +21,7 @@ from openhands.events.action.agent import AgentFinishAction
 from openhands.events.event import Event, EventSource
 from openhands.llm.metrics import Metrics
 from openhands.memory.view import View
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.files import FileStore
 from openhands.storage.locations import get_conversation_agent_state_filename

@@ -84,6 +85,7 @@ class State:
             limit_increase_amount=100, current_value=0, max_value=100
         )
     )
+    convo_stats: ConversationStats | None = None
     budget_flag: BudgetControlFlag | None = None
     confirmation_mode: bool = False
     history: list[Event] = field(default_factory=list)
@@ -91,8 +93,7 @@ class State:
     outputs: dict = field(default_factory=dict)
     agent_state: AgentState = AgentState.LOADING
     resume_state: AgentState | None = None
-    # global metrics for the current task
-    metrics: Metrics = field(default_factory=Metrics)
     # root agent has level 0, and every delegate increases the level by one
     delegate_level: int = 0
     # start_id and end_id track the range of events in history
@@ -116,9 +117,14 @@ class State:
     local_metrics: Metrics | None = None
     delegates: dict[tuple[int, int], tuple[str, str]] | None = None

+    metrics: Metrics = field(default_factory=Metrics)
+
     def save_to_session(
         self, sid: str, file_store: FileStore, user_id: str | None
     ) -> None:
+        convo_stats = self.convo_stats
+        self.convo_stats = None  # Don't save convo stats, handles itself
+
         pickled = pickle.dumps(self)
         logger.debug(f'Saving state to session {sid}:{self.agent_state}')
         encoded = base64.b64encode(pickled).decode('utf-8')
@@ -138,6 +144,8 @@ class State:
             logger.error(f'Failed to save state to session: {e}')
             raise e

+        self.convo_stats = convo_stats  # restore reference
+
     @staticmethod
     def restore_from_session(
         sid: str, file_store: FileStore, user_id: str | None = None

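The save path detaches convo_stats before pickling because the stats object persists its own metrics (see save_metrics in the state tracker below). A generic version of the detach-and-restore pattern, with invented names; note that the hunk above restores the reference only after the except block, so a failed pickle re-raises before the restore runs, whereas a try/finally would restore unconditionally:

# --- Illustrative sketch, not part of the diff ---
import pickle


def pickle_without(obj, attr: str) -> bytes:
    """Pickle obj with one attribute temporarily detached."""
    saved = getattr(obj, attr)
    setattr(obj, attr, None)  # keep the self-persisting field out of the pickle
    try:
        return pickle.dumps(obj)
    finally:
        setattr(obj, attr, saved)  # restore the live reference
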
@@ -1,4 +1,3 @@
-from openhands.controller.agent import Agent
 from openhands.controller.state.control_flags import (
     BudgetControlFlag,
     IterationControlFlag,
@@ -14,7 +13,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
 from openhands.events.observation.empty import NullObservation
 from openhands.events.serialization.event import event_to_trajectory
 from openhands.events.stream import EventStream
-from openhands.llm.metrics import Metrics
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.files import FileStore


@@ -51,8 +50,8 @@ class StateTracker:
     def set_initial_state(
         self,
         id: str,
-        agent: Agent,
         state: State | None,
+        convo_stats: ConversationStats,
         max_iterations: int,
         max_budget_per_task: float | None,
         confirmation_mode: bool = False,
@@ -75,6 +74,7 @@ class StateTracker:
                 session_id=id.removesuffix('-delegate'),
                 user_id=self.user_id,
                 inputs={},
+                convo_stats=convo_stats,
                 iteration_flag=IterationControlFlag(
                     limit_increase_amount=max_iterations,
                     current_value=0,
@@ -99,13 +99,7 @@ class StateTracker:
         if self.state.start_id <= -1:
             self.state.start_id = 0

-        logger.info(
-            f'AgentController {id} initializing history from event {self.state.start_id}',
-        )
-
-        # Share the state metrics with the agent's LLM metrics
-        # This ensures that all accumulated metrics are always in sync between controller and llm
-        agent.llm.metrics = self.state.metrics
+        state.convo_stats = convo_stats

     def _init_history(self, event_stream: EventStream) -> None:
         """Initializes the agent's history from the event stream.
@@ -254,6 +248,9 @@ class StateTracker:
         if self.sid and self.file_store:
             self.state.save_to_session(self.sid, self.file_store, self.user_id)

+            if self.state.convo_stats:
+                self.state.convo_stats.save_metrics()
+
     def run_control_flags(self):
         """Performs one step of the control flags"""
         self.state.iteration_flag.step()
@@ -264,20 +261,8 @@ class StateTracker:
         """Ensures that budget flag is up to date with accumulated costs from llm completions
         Budget flag will monitor for when budget is exceeded
         """
-        if self.state.budget_flag:
-            self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
-
-    def merge_metrics(self, metrics: Metrics):
-        """Merges metrics with the state metrics
-
-        NOTE: this should be refactored in the future. We should have services (draft llm, title autocomplete, condenser, etc)
-        use their own LLMs, but the metrics object should be shared. This way we have one source of truth for accumulated costs from
-        all services
-
-        This would prevent having fragmented stores for metrics, and we don't have the burden of deciding where and how to store them
-        if we decide introduce more specialized services that require llm completions
-        """
-        self.state.metrics.merge(metrics)
-        if self.state.budget_flag:
-            self.state.budget_flag.current_value = self.state.metrics.accumulated_cost
+        # Sync cost across all llm services from llm registry
+        if self.state.budget_flag and self.state.convo_stats:
+            self.state.budget_flag.current_value = (
+                self.state.convo_stats.get_combined_metrics().accumulated_cost
+            )

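With merge_metrics removed, auxiliary services no longer push their costs into the controller's state; any LLM created through the registry reports into the shared ConversationStats, and the budget flag reads the combined figure before each step. A toy walkthrough of the accounting, with invented numbers:

# --- Illustrative sketch, not part of the diff ---
agent_cost = 0.42      # accumulated by the 'agent' LLM service
condenser_cost = 0.08  # accumulated by a condenser LLM service

# get_combined_metrics().accumulated_cost sums spend across services:
combined_cost = agent_cost + condenser_cost  # 0.50

max_budget_per_task = 1.00
budget_exceeded = combined_cost > max_budget_per_task  # False
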
@@ -157,13 +157,16 @@ class OpenHandsConfig(BaseModel):
         """Get a map of agent names to llm configs."""
         return {name: self.get_llm_config_from_agent(name) for name in self.agents}

-    def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
-        agent_config: AgentConfig = self.get_agent_config(name)
+    def get_llm_config_from_agent_config(self, agent_config: AgentConfig):
         llm_config_name = (
             agent_config.llm_config if agent_config.llm_config is not None else 'llm'
         )
         return self.get_llm_config(llm_config_name)

+    def get_llm_config_from_agent(self, name: str = 'agent') -> LLMConfig:
+        agent_config: AgentConfig = self.get_agent_config(name)
+        return self.get_llm_config_from_agent_config(agent_config)
+
     def get_agent_configs(self) -> dict[str, AgentConfig]:
         return self.agents

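The split separates name-based lookup from the actual resolution, so the registry (which holds an AgentConfig rather than an agent name) can reuse the same logic. A resolution sketch; the 'CodeActAgent' key is an assumed example:

# --- Illustrative sketch, not part of the diff ---
config: OpenHandsConfig = ...  # loaded elsewhere

# By agent name (existing helper, now a thin wrapper):
llm_config = config.get_llm_config_from_agent('CodeActAgent')

# Directly from an AgentConfig (new helper; falls back to the 'llm'
# config group when agent_config.llm_config is None):
agent_config = config.get_agent_config('CodeActAgent')
same_llm_config = config.get_llm_config_from_agent_config(agent_config)
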
@@ -6,7 +6,6 @@ from typing import Callable, Protocol

 import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
 import openhands.cli.suppress_warnings  # noqa: F401
-from openhands.controller.agent import Agent
 from openhands.controller.replay import ReplayManager
 from openhands.controller.state.state import State
 from openhands.core.config import (
@@ -33,10 +32,12 @@ from openhands.events.action.action import Action
 from openhands.events.event import Event
 from openhands.events.observation import AgentStateChangedObservation
 from openhands.io import read_input, read_task
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.mcp import add_mcp_tools_to_agent
 from openhands.memory.memory import Memory
 from openhands.runtime.base import Runtime
 from openhands.utils.async_utils import call_async_from_sync
+from openhands.utils.utils import create_registry_and_convo_stats


 class FakeUserResponseFunc(Protocol):
@@ -53,12 +54,12 @@ async def run_controller(
     initial_user_action: Action,
     sid: str | None = None,
     runtime: Runtime | None = None,
-    agent: Agent | None = None,
     exit_on_message: bool = False,
     fake_user_response_fn: FakeUserResponseFunc | None = None,
     headless_mode: bool = True,
     memory: Memory | None = None,
     conversation_instructions: str | None = None,
+    llm_registry: LLMRegistry | None = None,
 ) -> State | None:
     """Main coroutine to run the agent controller with task input flexibility.

@@ -70,7 +71,6 @@ async def run_controller(
         sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
             Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
         runtime: (optional) A runtime for the agent to run on.
-        agent: (optional) A agent to run.
         exit_on_message: quit if agent asks for a message from user (optional)
         fake_user_response_fn: An optional function that receives the current state
             (could be None) and returns a fake user response.
@@ -98,8 +98,13 @@ async def run_controller(
     """
     sid = sid or generate_sid(config)

-    if agent is None:
-        agent = create_agent(config)
+    llm_registry, convo_stats, config = create_registry_and_convo_stats(
+        config,
+        sid,
+        None,
+    )
+
+    agent = create_agent(config, llm_registry)

     # when the runtime is created, it will be connected and clone the selected repository
     repo_directory = None
@@ -108,6 +113,7 @@ async def run_controller(
     repo_tokens = get_provider_tokens()
     runtime = create_runtime(
         config,
+        llm_registry,
         sid=sid,
         headless_mode=headless_mode,
         agent=agent,
@@ -159,7 +165,7 @@ async def run_controller(
     )

     controller, initial_state = create_controller(
-        agent, runtime, config, replay_events=replay_events
+        agent, runtime, config, convo_stats, replay_events=replay_events
     )

     assert isinstance(initial_user_action, Action), (

@@ -21,12 +21,13 @@ from openhands.integrations.provider import (
     ProviderToken,
     ProviderType,
 )
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.memory import Memory
 from openhands.microagent.microagent import BaseMicroagent
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
 from openhands.security import SecurityAnalyzer, options
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage import get_file_store
 from openhands.storage.data_models.user_secrets import UserSecrets
 from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
@@ -34,6 +35,7 @@ from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync

 def create_runtime(
     config: OpenHandsConfig,
+    llm_registry: LLMRegistry,
     sid: str | None = None,
     headless_mode: bool = True,
     agent: Agent | None = None,
@@ -82,6 +84,7 @@ def create_runtime(
         sid=session_id,
         plugins=agent_cls.sandbox_plugins,
         headless_mode=headless_mode,
+        llm_registry=llm_registry,
         git_provider_tokens=git_provider_tokens,
     )

@@ -203,16 +206,11 @@ def create_memory(
     return memory


-def create_agent(config: OpenHandsConfig) -> Agent:
+def create_agent(config: OpenHandsConfig, llm_registry: LLMRegistry) -> Agent:
     agent_cls: type[Agent] = Agent.get_cls(config.default_agent)
     agent_config = config.get_agent_config(config.default_agent)
-    llm_config = config.get_llm_config_from_agent(config.default_agent)
-
-    agent = agent_cls(
-        llm=LLM(config=llm_config),
-        config=agent_config,
-    )
+    config.get_llm_config_from_agent(config.default_agent)
+    agent = agent_cls(config=agent_config, llm_registry=llm_registry)

     return agent


@@ -220,6 +218,7 @@ def create_controller(
     agent: Agent,
     runtime: Runtime,
     config: OpenHandsConfig,
+    convo_stats: ConversationStats,
     headless_mode: bool = True,
     replay_events: list[Event] | None = None,
 ) -> tuple[AgentController, State | None]:
@@ -237,6 +236,7 @@ def create_controller(

     controller = AgentController(
         agent=agent,
+        convo_stats=convo_stats,
         iteration_delta=config.max_iterations,
         budget_per_task_delta=config.max_budget_per_task,
         agent_to_llm_config=config.get_agent_to_llm_config_map(),

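Putting the setup pieces together, the headless creation flow after this commit looks roughly like this (condensed from the hunks in this file and in run_controller above):

# --- Illustrative sketch, not part of the diff ---
llm_registry, convo_stats, config = create_registry_and_convo_stats(
    config, sid, None
)
agent = create_agent(config, llm_registry)  # agent pulls its LLM from the registry
runtime = create_runtime(config, llm_registry, sid=sid, agent=agent)
controller, initial_state = create_controller(
    agent, runtime, config, convo_stats
)
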
@@ -8,6 +8,7 @@ from typing import Any, Callable
 import httpx

 from openhands.core.config import LLMConfig
+from openhands.llm.metrics import Metrics

 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
@@ -34,7 +35,6 @@ from openhands.llm.fn_call_converter import (
     convert_fncall_messages_to_non_fncall_messages,
     convert_non_fncall_messages_to_fncall_messages,
 )
-from openhands.llm.metrics import Metrics
 from openhands.llm.retry_mixin import RetryMixin

 __all__ = ['LLM']
@@ -133,6 +133,7 @@ class LLM(RetryMixin, DebugMixin):
     def __init__(
         self,
         config: LLMConfig,
+        service_id: str,
         metrics: Metrics | None = None,
         retry_listener: Callable[[int, int], None] | None = None,
     ) -> None:
@@ -145,11 +146,12 @@ class LLM(RetryMixin, DebugMixin):
             metrics: The metrics to use.
         """
         self._tried_model_info = False
+        self.cost_metric_supported: bool = True
+        self.config: LLMConfig = copy.deepcopy(config)
+        self.service_id = service_id
         self.metrics: Metrics = (
             metrics if metrics is not None else Metrics(model_name=config.model)
         )
-        self.cost_metric_supported: bool = True
-        self.config: LLMConfig = copy.deepcopy(config)

         self.model_info: ModelInfo | None = None
         self.retry_listener = retry_listener
@@ -408,8 +410,7 @@ class LLM(RetryMixin, DebugMixin):
         assert self.config.log_completions_folder is not None
         log_file = os.path.join(
             self.config.log_completions_folder,
-            # use the metric model name (for draft editor)
-            f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
+            f'{self.config.model.replace("/", "__")}-{time.time()}.json',
         )

         # set up the dict to be logged

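service_id is now a required constructor argument, so every LLM instance is addressable by the service that owns it. Direct construction looks like this; the model name and service ID are invented, and most call sites are expected to go through the LLMRegistry added below:

# --- Illustrative sketch, not part of the diff ---
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

llm = LLM(config=LLMConfig(model='gpt-4o'), service_id='draft_editor')
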
openhands/llm/llm_registry.py (new file, 132 lines)
@@ -0,0 +1,132 @@
+import copy
+from typing import Any, Callable
+from uuid import uuid4
+
+from pydantic import BaseModel, ConfigDict
+
+from openhands.core.config.agent_config import AgentConfig
+from openhands.core.config.llm_config import LLMConfig
+from openhands.core.config.openhands_config import OpenHandsConfig
+from openhands.core.logger import openhands_logger as logger
+from openhands.llm.llm import LLM
+
+
+class RegistryEvent(BaseModel):
+    llm: LLM
+    service_id: str
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
+
+
+class LLMRegistry:
+    def __init__(
+        self,
+        config: OpenHandsConfig,
+        agent_cls: str | None = None,
+        retry_listener: Callable[[int, int], None] | None = None,
+    ):
+        self.registry_id = str(uuid4())
+        self.config = copy.deepcopy(config)
+        self.retry_listner = retry_listener
+        self.agent_to_llm_config = self.config.get_agent_to_llm_config_map()
+        self.service_to_llm: dict[str, LLM] = {}
+        self.subscriber: Callable[[Any], None] | None = None
+
+        selected_agent_cls = self.config.default_agent
+        if agent_cls:
+            selected_agent_cls = agent_cls
+
+        agent_name = selected_agent_cls if selected_agent_cls is not None else 'agent'
+        llm_config = self.config.get_llm_config_from_agent(agent_name)
+        self.active_agent_llm: LLM = self.get_llm('agent', llm_config)
+
+    def _create_new_llm(
+        self, service_id: str, config: LLMConfig, with_listener: bool = True
+    ) -> LLM:
+        if with_listener:
+            llm = LLM(
+                service_id=service_id, config=config, retry_listener=self.retry_listner
+            )
+        else:
+            llm = LLM(service_id=service_id, config=config)
+        self.service_to_llm[service_id] = llm
+        self.notify(RegistryEvent(llm=llm, service_id=service_id))
+        return llm
+
+    def request_extraneous_completion(
+        self, service_id: str, llm_config: LLMConfig, messages: list[dict[str, str]]
+    ) -> str:
+        logger.info(f'extraneous completion: {service_id}')
+        if service_id not in self.service_to_llm:
+            self._create_new_llm(
+                config=llm_config, service_id=service_id, with_listener=False
+            )
+
+        llm = self.service_to_llm[service_id]
+        response = llm.completion(messages=messages)
+        return response.choices[0].message.content.strip()
+
+    def get_llm_from_agent_config(self, service_id: str, agent_config: AgentConfig):
+        llm_config = self.config.get_llm_config_from_agent_config(agent_config)
+        if service_id in self.service_to_llm:
+            if self.service_to_llm[service_id].config != llm_config:
+                # TODO: update llm config internally
+                # Done when agent delegates has different config, we should reuse the existing LLM
+                pass
+            return self.service_to_llm[service_id]
+
+        return self._create_new_llm(config=llm_config, service_id=service_id)
+
+    def get_llm(
+        self,
+        service_id: str,
+        config: LLMConfig | None = None,
+    ):
+        logger.info(
+            f'[LLM registry {self.registry_id}]: Registering service for {service_id}'
+        )
+
+        # Attempting to switch configs for existing LLM
+        if (
+            service_id in self.service_to_llm
+            and self.service_to_llm[service_id].config != config
+        ):
+            raise ValueError(
+                f'Requesting same service ID {service_id} with different config, use a new service ID'
+            )
+
+        if service_id in self.service_to_llm:
+            return self.service_to_llm[service_id]
+
+        if not config:
+            raise ValueError('Requesting new LLM without specifying LLM config')
+
+        return self._create_new_llm(config=config, service_id=service_id)
+
+    def get_active_llm(self) -> LLM:
+        return self.active_agent_llm
+
+    def _set_active_llm(self, service_id) -> None:
+        if service_id not in self.service_to_llm:
+            raise ValueError(f'Unrecognized service ID: {service_id}')
+        self.active_agent_llm = self.service_to_llm[service_id]
+
+    def subscribe(self, callback: Callable[[RegistryEvent], None]) -> None:
+        self.subscriber = callback
+
+        # Subscriptions happen after default llm is initialized
+        # Notify service of this llm
+        self.notify(
+            RegistryEvent(
+                llm=self.active_agent_llm, service_id=self.active_agent_llm.service_id
+            )
+        )
+
+    def notify(self, event: RegistryEvent):
+        if self.subscriber:
+            try:
+                self.subscriber(event)
+            except Exception as e:
+                logger.warning(f'Failed to emit event: {e}')

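A rough usage sketch for the registry, exercising only the methods defined above; the service IDs and model name are invented:

# --- Illustrative sketch, not part of the diff ---
from openhands.core.config.llm_config import LLMConfig

registry = LLMRegistry(config)  # config: an OpenHandsConfig loaded elsewhere

# Observe every LLM the registry creates (or re-announces on subscribe):
registry.subscribe(lambda event: print(event.service_id, event.llm.config.model))

# Reuse-or-create under a service ID; the same ID with a different
# config raises ValueError rather than silently swapping models:
draft_config = LLMConfig(model='gpt-4o-mini')
draft_llm = registry.get_llm('draft_editor', draft_config)

# One-off completion for a side service, created without a retry listener:
title = registry.request_extraneous_completion(
    'conversation_title',
    draft_config,
    [{'role': 'user', 'content': 'Summarize this conversation in five words.'}],
)
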
@@ -10,6 +10,7 @@ from openhands.controller.state.state import State
 from openhands.core.config.condenser_config import CondenserConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.action.agent import CondensationAction
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.view import View

 CONDENSER_METADATA_KEY = 'condenser_meta'
@@ -144,7 +145,9 @@ class Condenser(ABC):
         CONDENSER_REGISTRY[configuration_type] = cls

     @classmethod
-    def from_config(cls, config: CondenserConfig) -> Condenser:
+    def from_config(
+        cls, config: CondenserConfig, llm_registry: LLMRegistry
+    ) -> Condenser:
         """Create a condenser from a configuration object.

         Args:
@@ -158,7 +161,7 @@ class Condenser(ABC):
         """
         try:
             condenser_class = CONDENSER_REGISTRY[type(config)]
-            return condenser_class.from_config(config)
+            return condenser_class.from_config(config, llm_registry)
         except KeyError:
             raise ValueError(f'Unknown condenser config: {config}')

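Every registered condenser now receives the registry in from_config, even implementations that never touch an LLM (the condensers below simply ignore it). A hypothetical LLM-backed condenser would resolve its model through the registry so its spend lands in the shared conversation metrics; everything in this sketch (class, config type, service ID) is invented:

# --- Illustrative sketch, not part of the diff ---
from __future__ import annotations

from openhands.llm.llm_registry import LLMRegistry
from openhands.memory.condenser.condenser import RollingCondenser


class MyLLMCondenser(RollingCondenser):  # hypothetical
    @classmethod
    def from_config(
        cls, config: 'MyLLMCondenserConfig', llm_registry: LLMRegistry
    ) -> MyLLMCondenser:
        # Resolve the condenser's LLM via the shared registry instead of
        # constructing one, so its cost is tracked with the conversation.
        llm = llm_registry.get_llm('condenser', config.llm_config)
        return cls(llm=llm, **config.model_dump(exclude={'type', 'llm_config'}))
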
@@ -2,6 +2,7 @@ from __future__ import annotations

 from openhands.core.config.condenser_config import AmortizedForgettingCondenserConfig
 from openhands.events.action.agent import CondensationAction
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -58,7 +59,9 @@ class AmortizedForgettingCondenser(RollingCondenser):

     @classmethod
     def from_config(
-        cls, config: AmortizedForgettingCondenserConfig
+        cls,
+        config: AmortizedForgettingCondenserConfig,
+        llm_registry: LLMRegistry,
     ) -> AmortizedForgettingCondenser:
         return AmortizedForgettingCondenser(**config.model_dump(exclude={'type'}))

@@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import BrowserOutputCondenserConfig
 from openhands.events.event import Event
 from openhands.events.observation import BrowserOutputObservation
 from openhands.events.observation.agent import AgentCondensationObservation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -40,7 +41,7 @@ class BrowserOutputCondenser(Condenser):

     @classmethod
     def from_config(
-        cls, config: BrowserOutputCondenserConfig
+        cls, config: BrowserOutputCondenserConfig, llm_registry: LLMRegistry
     ) -> BrowserOutputCondenser:
         return BrowserOutputCondenser(**config.model_dump(exclude={'type'}))

@@ -9,6 +9,7 @@ from openhands.events.action.agent import (
 from openhands.events.action.message import MessageAction, SystemMessageAction
 from openhands.events.event import EventSource
 from openhands.events.observation import Observation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View


@@ -177,7 +178,9 @@ class ConversationWindowCondenser(RollingCondenser):

     @classmethod
     def from_config(
-        cls, _config: ConversationWindowCondenserConfig
+        cls,
+        _config: ConversationWindowCondenserConfig,
+        llm_registry: LLMRegistry,
     ) -> ConversationWindowCondenser:
         return ConversationWindowCondenser()
@@ -6,6 +6,7 @@ from pydantic import BaseModel
 from openhands.core.config.condenser_config import LLMAttentionCondenserConfig
 from openhands.events.action.agent import CondensationAction
 from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -22,7 +23,12 @@ class ImportantEventSelection(BaseModel):
 class LLMAttentionCondenser(RollingCondenser):
     """Rolling condenser strategy that uses an LLM to select the most important events when condensing the history."""

-    def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1):
+    def __init__(
+        self,
+        llm: LLM,
+        max_size: int = 100,
+        keep_first: int = 1,
+    ):
         if keep_first >= max_size // 2:
             raise ValueError(
                 f'keep_first ({keep_first}) must be less than half of max_size ({max_size})'
@@ -113,15 +119,19 @@ class LLMAttentionCondenser(RollingCondenser):
         return len(view) > self.max_size

     @classmethod
-    def from_config(cls, config: LLMAttentionCondenserConfig) -> LLMAttentionCondenser:
+    def from_config(
+        cls, config: LLMAttentionCondenserConfig, llm_registry: LLMRegistry
+    ) -> LLMAttentionCondenser:
         # This condenser cannot take advantage of prompt caching. If it happens
         # to be set, we'll pay for the cache writes but never get a chance to
         # save on a read.
         llm_config = config.llm_config.model_copy()
         llm_config.caching_prompt = False
+
+        llm = llm_registry.get_llm('condenser', llm_config)
+
         return LLMAttentionCondenser(
-            llm=LLM(config=llm_config),
+            llm=llm,
             max_size=config.max_size,
             keep_first=config.keep_first,
         )
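Rather than instantiating LLM directly, the condenser now asks the registry for an instance under the 'condenser' service id, which is what lets token costs be attributed per service. A sketch of the assumed get_llm contract (one instance per service id, created lazily and then reused; this caching behavior is an assumption, not shown in the diff):

    llm_config = config.llm_config.model_copy()
    llm_config.caching_prompt = False  # as in the hunk above

    llm_a = llm_registry.get_llm('condenser', llm_config)
    llm_b = llm_registry.get_llm('condenser', llm_config)
    assert llm_a is llm_b  # assumption: the registry caches by service id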
@@ -5,7 +5,8 @@ from openhands.core.message import Message, TextContent
 from openhands.events.action.agent import CondensationAction
 from openhands.events.observation.agent import AgentCondensationObservation
 from openhands.events.serialization.event import truncate_content
-from openhands.llm import LLM
+from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -154,16 +155,17 @@ CURRENT_STATE: Last flip: Heads, Haiku count: 15/20"""

     @classmethod
     def from_config(
-        cls, config: LLMSummarizingCondenserConfig
+        cls, config: LLMSummarizingCondenserConfig, llm_registry: LLMRegistry
     ) -> LLMSummarizingCondenser:
         # This condenser cannot take advantage of prompt caching. If it happens
         # to be set, we'll pay for the cache writes but never get a chance to
         # save on a read.
         llm_config = config.llm_config.model_copy()
         llm_config.caching_prompt = False
+        llm = llm_registry.get_llm('condenser', llm_config)

         return LLMSummarizingCondenser(
-            llm=LLM(config=llm_config),
+            llm=llm,
             max_size=config.max_size,
             keep_first=config.keep_first,
             max_event_length=config.max_event_length,
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from openhands.core.config.condenser_config import NoOpCondenserConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -12,7 +13,9 @@ class NoOpCondenser(Condenser):
         return view

     @classmethod
-    def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:
+    def from_config(
+        cls, config: NoOpCondenserConfig, llm_registry: LLMRegistry
+    ) -> NoOpCondenser:
         return NoOpCondenser()
@@ -4,6 +4,7 @@ from openhands.core.config.condenser_config import ObservationMaskingCondenserConfig
 from openhands.events.event import Event
 from openhands.events.observation import Observation
 from openhands.events.observation.agent import AgentCondensationObservation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -28,7 +29,9 @@ class ObservationMaskingCondenser(Condenser):

     @classmethod
     def from_config(
-        cls, config: ObservationMaskingCondenserConfig
+        cls,
+        config: ObservationMaskingCondenserConfig,
+        llm_registry: LLMRegistry,
     ) -> ObservationMaskingCondenser:
         return ObservationMaskingCondenser(**config.model_dump(exclude={'type'}))
@@ -4,6 +4,7 @@ from contextlib import contextmanager

 from openhands.controller.state.state import State
 from openhands.core.config.condenser_config import CondenserPipelineConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser
 from openhands.memory.view import View

@@ -39,8 +40,10 @@ class CondenserPipeline(Condenser):
         return result

     @classmethod
-    def from_config(cls, config: CondenserPipelineConfig) -> CondenserPipeline:
-        condensers = [Condenser.from_config(c) for c in config.condensers]
+    def from_config(
+        cls, config: CondenserPipelineConfig, llm_registry: LLMRegistry
+    ) -> CondenserPipeline:
+        condensers = [Condenser.from_config(c, llm_registry) for c in config.condensers]
         return CondenserPipeline(*condensers)
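Since the pipeline recurses through Condenser.from_config with the same registry, every nested condenser, however deep, shares one registry instance. A sketch (the child config defaults are assumed):

    pipeline_config = CondenserPipelineConfig(
        condensers=[
            BrowserOutputCondenserConfig(),
            LLMSummarizingCondenserConfig(llm_config=condenser_llm_config),
        ]
    )
    # One registry serves both children; the LLM-backed child registers
    # itself under the 'condenser' service id shown in the earlier hunks.
    pipeline = Condenser.from_config(pipeline_config, registry)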
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from openhands.core.config.condenser_config import RecentEventsCondenserConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import Condensation, Condenser, View


@@ -21,7 +22,9 @@ class RecentEventsCondenser(Condenser):
         return View(events=head + tail)

     @classmethod
-    def from_config(cls, config: RecentEventsCondenserConfig) -> RecentEventsCondenser:
+    def from_config(
+        cls, config: RecentEventsCondenserConfig, llm_registry: LLMRegistry
+    ) -> RecentEventsCondenser:
         return RecentEventsCondenser(**config.model_dump(exclude={'type'}))
@@ -13,7 +13,8 @@ from openhands.core.message import Message, TextContent
 from openhands.events.action.agent import CondensationAction
 from openhands.events.observation.agent import AgentCondensationObservation
 from openhands.events.serialization.event import truncate_content
-from openhands.llm import LLM
+from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser.condenser import (
     Condensation,
     RollingCondenser,
@@ -180,15 +181,14 @@ class StructuredSummaryCondenser(RollingCondenser):
         if max_size < 1:
             raise ValueError(f'max_size ({max_size}) cannot be non-positive')

-        if not llm.is_function_calling_active():
-            raise ValueError(
-                'LLM must support function calling to use StructuredSummaryCondenser'
-            )
-
         self.max_size = max_size
         self.keep_first = keep_first
         self.max_event_length = max_event_length
         self.llm = llm
+        if not self.llm.is_function_calling_active():
+            raise ValueError(
+                'LLM must support function calling to use StructuredSummaryCondenser'
+            )

         super().__init__()
@@ -309,16 +309,17 @@ Capture all relevant information, especially:

     @classmethod
     def from_config(
-        cls, config: StructuredSummaryCondenserConfig
+        cls, config: StructuredSummaryCondenserConfig, llm_registry: LLMRegistry
     ) -> StructuredSummaryCondenser:
         # This condenser cannot take advantage of prompt caching. If it happens
         # to be set, we'll pay for the cache writes but never get a chance to
         # save on a read.
         llm_config = config.llm_config.model_copy()
         llm_config.caching_prompt = False
+        llm = llm_registry.get_llm('condenser', llm_config)

         return StructuredSummaryCondenser(
-            llm=LLM(config=llm_config),
+            llm=llm,
             max_size=config.max_size,
             keep_first=config.keep_first,
             max_event_length=config.max_event_length,
|||||||
@@ -23,7 +23,7 @@ class ServiceContext:
|
|||||||
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
|
def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None):
|
||||||
self._strategy = strategy
|
self._strategy = strategy
|
||||||
if llm_config is not None:
|
if llm_config is not None:
|
||||||
self.llm = LLM(llm_config)
|
self.llm = LLM(llm_config, service_id='resolver')
|
||||||
|
|
||||||
def set_strategy(self, strategy: IssueHandlerInterface) -> None:
|
def set_strategy(self, strategy: IssueHandlerInterface) -> None:
|
||||||
self._strategy = strategy
|
self._strategy = strategy
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from openhands.events.observation import (
|
|||||||
)
|
)
|
||||||
from openhands.events.stream import EventStreamSubscriber
|
from openhands.events.stream import EventStreamSubscriber
|
||||||
from openhands.integrations.service_types import ProviderType
|
from openhands.integrations.service_types import ProviderType
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.resolver.interfaces.issue import Issue
|
from openhands.resolver.interfaces.issue import Issue
|
||||||
from openhands.resolver.interfaces.issue_definitions import (
|
from openhands.resolver.interfaces.issue_definitions import (
|
||||||
ServiceContextIssue,
|
ServiceContextIssue,
|
||||||
@@ -412,7 +413,8 @@ class IssueResolver:
|
|||||||
shutil.rmtree(self.workspace_base)
|
shutil.rmtree(self.workspace_base)
|
||||||
shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base)
|
shutil.copytree(os.path.join(self.output_dir, 'repo'), self.workspace_base)
|
||||||
|
|
||||||
runtime = create_runtime(self.app_config)
|
llm_registry = LLMRegistry(self.app_config)
|
||||||
|
runtime = create_runtime(self.app_config, llm_registry)
|
||||||
await runtime.connect()
|
await runtime.connect()
|
||||||
|
|
||||||
def on_event(evt: Event) -> None:
|
def on_event(evt: Event) -> None:
|
||||||
|
|||||||
@@ -463,7 +463,7 @@ def update_existing_pull_request(
|
|||||||
|
|
||||||
# Summarize with LLM if provided
|
# Summarize with LLM if provided
|
||||||
if llm_config is not None:
|
if llm_config is not None:
|
||||||
llm = LLM(llm_config)
|
llm = LLM(llm_config, service_id='resolver')
|
||||||
with open(
|
with open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
os.path.dirname(__file__),
|
os.path.dirname(__file__),
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ from openhands.integrations.provider import (
|
|||||||
ProviderType,
|
ProviderType,
|
||||||
)
|
)
|
||||||
from openhands.integrations.service_types import AuthenticationError
|
from openhands.integrations.service_types import AuthenticationError
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.microagent import (
|
from openhands.microagent import (
|
||||||
BaseMicroagent,
|
BaseMicroagent,
|
||||||
load_microagents_from_dir,
|
load_microagents_from_dir,
|
||||||
@@ -125,6 +126,7 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
self,
|
self,
|
||||||
config: OpenHandsConfig,
|
config: OpenHandsConfig,
|
||||||
event_stream: EventStream,
|
event_stream: EventStream,
|
||||||
|
llm_registry: LLMRegistry,
|
||||||
sid: str = 'default',
|
sid: str = 'default',
|
||||||
plugins: list[PluginRequirement] | None = None,
|
plugins: list[PluginRequirement] | None = None,
|
||||||
env_vars: dict[str, str] | None = None,
|
env_vars: dict[str, str] | None = None,
|
||||||
@@ -178,7 +180,9 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
|
|
||||||
# Load mixins
|
# Load mixins
|
||||||
FileEditRuntimeMixin.__init__(
|
FileEditRuntimeMixin.__init__(
|
||||||
self, enable_llm_editor=config.get_agent_config().enable_llm_editor
|
self,
|
||||||
|
enable_llm_editor=config.get_agent_config().enable_llm_editor,
|
||||||
|
llm_registry=llm_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
|
|||||||
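Anyone constructing a runtime directly must now supply a registry as the third positional argument. A sketch (the DockerRuntime choice is illustrative; every Runtime subclass in the hunks below follows the same signature):

    llm_registry = LLMRegistry(config)
    runtime = DockerRuntime(
        config,
        event_stream,
        llm_registry,
        sid='default',
    )

The base class forwards the registry into FileEditRuntimeMixin, which uses it to build the draft-editor LLM (see the mixin hunks further down).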
@@ -43,6 +43,7 @@ from openhands.events.observation import (
 from openhands.events.serialization import event_to_dict, observation_from_dict
 from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
 from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.base import Runtime
 from openhands.runtime.plugins import PluginRequirement
 from openhands.runtime.utils.request import send_request
@@ -68,6 +69,7 @@ class ActionExecutionClient(Runtime):
         self,
         config: OpenHandsConfig,
         event_stream: EventStream,
+        llm_registry: LLMRegistry,
         sid: str = 'default',
         plugins: list[PluginRequirement] | None = None,
         env_vars: dict[str, str] | None = None,
@@ -85,6 +87,7 @@ class ActionExecutionClient(Runtime):
         super().__init__(
             config,
             event_stream,
+            llm_registry,
             sid,
             plugins,
             env_vars,
|||||||
@@ -46,6 +46,7 @@ from openhands.events.observation import (
|
|||||||
Observation,
|
Observation,
|
||||||
)
|
)
|
||||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.runtime.base import Runtime
|
from openhands.runtime.base import Runtime
|
||||||
from openhands.runtime.plugins import PluginRequirement
|
from openhands.runtime.plugins import PluginRequirement
|
||||||
from openhands.runtime.runtime_status import RuntimeStatus
|
from openhands.runtime.runtime_status import RuntimeStatus
|
||||||
@@ -107,6 +108,7 @@ class CLIRuntime(Runtime):
|
|||||||
self,
|
self,
|
||||||
config: OpenHandsConfig,
|
config: OpenHandsConfig,
|
||||||
event_stream: EventStream,
|
event_stream: EventStream,
|
||||||
|
llm_registry: LLMRegistry,
|
||||||
sid: str = 'default',
|
sid: str = 'default',
|
||||||
plugins: list[PluginRequirement] | None = None,
|
plugins: list[PluginRequirement] | None = None,
|
||||||
env_vars: dict[str, str] | None = None,
|
env_vars: dict[str, str] | None = None,
|
||||||
@@ -119,6 +121,7 @@ class CLIRuntime(Runtime):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
config,
|
config,
|
||||||
event_stream,
|
event_stream,
|
||||||
|
llm_registry,
|
||||||
sid,
|
sid,
|
||||||
plugins,
|
plugins,
|
||||||
env_vars,
|
env_vars,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from openhands.core.logger import DEBUG, DEBUG_RUNTIME
|
|||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.events import EventStream
|
from openhands.events import EventStream
|
||||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||||
ActionExecutionClient,
|
ActionExecutionClient,
|
||||||
@@ -90,6 +91,7 @@ class DockerRuntime(ActionExecutionClient):
|
|||||||
self,
|
self,
|
||||||
config: OpenHandsConfig,
|
config: OpenHandsConfig,
|
||||||
event_stream: EventStream,
|
event_stream: EventStream,
|
||||||
|
llm_registry: LLMRegistry,
|
||||||
sid: str = 'default',
|
sid: str = 'default',
|
||||||
plugins: list[PluginRequirement] | None = None,
|
plugins: list[PluginRequirement] | None = None,
|
||||||
env_vars: dict[str, str] | None = None,
|
env_vars: dict[str, str] | None = None,
|
||||||
@@ -143,6 +145,7 @@ class DockerRuntime(ActionExecutionClient):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
config,
|
config,
|
||||||
event_stream,
|
event_stream,
|
||||||
|
llm_registry,
|
||||||
sid,
|
sid,
|
||||||
plugins,
|
plugins,
|
||||||
env_vars,
|
env_vars,
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from openhands.core.logger import DEBUG
|
|||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.events import EventStream
|
from openhands.events import EventStream
|
||||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||||
ActionExecutionClient,
|
ActionExecutionClient,
|
||||||
)
|
)
|
||||||
@@ -81,6 +82,7 @@ class KubernetesRuntime(ActionExecutionClient):
|
|||||||
self,
|
self,
|
||||||
config: OpenHandsConfig,
|
config: OpenHandsConfig,
|
||||||
event_stream: EventStream,
|
event_stream: EventStream,
|
||||||
|
llm_registry: LLMRegistry,
|
||||||
sid: str = 'default',
|
sid: str = 'default',
|
||||||
plugins: list[PluginRequirement] | None = None,
|
plugins: list[PluginRequirement] | None = None,
|
||||||
env_vars: dict[str, str] | None = None,
|
env_vars: dict[str, str] | None = None,
|
||||||
@@ -137,6 +139,7 @@ class KubernetesRuntime(ActionExecutionClient):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
config,
|
config,
|
||||||
event_stream,
|
event_stream,
|
||||||
|
llm_registry,
|
||||||
sid,
|
sid,
|
||||||
plugins,
|
plugins,
|
||||||
env_vars,
|
env_vars,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from openhands.events.observation import (
|
|||||||
)
|
)
|
||||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.runtime.impl.action_execution.action_execution_client import (
|
from openhands.runtime.impl.action_execution.action_execution_client import (
|
||||||
ActionExecutionClient,
|
ActionExecutionClient,
|
||||||
)
|
)
|
||||||
@@ -135,6 +136,7 @@ class LocalRuntime(ActionExecutionClient):
|
|||||||
self,
|
self,
|
||||||
config: OpenHandsConfig,
|
config: OpenHandsConfig,
|
||||||
event_stream: EventStream,
|
event_stream: EventStream,
|
||||||
|
llm_registry: LLMRegistry,
|
||||||
sid: str = 'default',
|
sid: str = 'default',
|
||||||
plugins: list[PluginRequirement] | None = None,
|
plugins: list[PluginRequirement] | None = None,
|
||||||
env_vars: dict[str, str] | None = None,
|
env_vars: dict[str, str] | None = None,
|
||||||
@@ -186,6 +188,7 @@ class LocalRuntime(ActionExecutionClient):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
config,
|
config,
|
||||||
event_stream,
|
event_stream,
|
||||||
|
llm_registry,
|
||||||
sid,
|
sid,
|
||||||
plugins,
|
plugins,
|
||||||
env_vars,
|
env_vars,
|
||||||
@@ -801,12 +804,6 @@ def _create_warm_server_in_background(
|
|||||||
|
|
||||||
def _get_plugins(config: OpenHandsConfig) -> list[PluginRequirement]:
|
def _get_plugins(config: OpenHandsConfig) -> list[PluginRequirement]:
|
||||||
from openhands.controller.agent import Agent
|
from openhands.controller.agent import Agent
|
||||||
from openhands.llm.llm import LLM
|
|
||||||
|
|
||||||
agent_config = config.get_agent_config(config.default_agent)
|
plugins = Agent.get_cls(config.default_agent).sandbox_plugins
|
||||||
llm = LLM(
|
|
||||||
config=config.get_llm_config_from_agent(config.default_agent),
|
|
||||||
)
|
|
||||||
agent = Agent.get_cls(config.default_agent)(llm, agent_config)
|
|
||||||
plugins = agent.sandbox_plugins
|
|
||||||
return plugins
|
return plugins
|
||||||
|
|||||||
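The _get_plugins simplification leans on sandbox_plugins being a class-level attribute on Agent subclasses, which this hunk implies: no LLM, registry, or agent instance is needed just to enumerate plugins.

    agent_cls = Agent.get_cls(config.default_agent)
    plugins = agent_cls.sandbox_plugins  # class attribute; no instance required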
@@ -19,6 +19,7 @@ from openhands.core.exceptions import (
 from openhands.core.logger import openhands_logger as logger
 from openhands.events import EventStream
 from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.builder.remote import RemoteRuntimeBuilder
 from openhands.runtime.impl.action_execution.action_execution_client import (
     ActionExecutionClient,
@@ -51,6 +52,7 @@ class RemoteRuntime(ActionExecutionClient):
         self,
         config: OpenHandsConfig,
         event_stream: EventStream,
+        llm_registry: LLMRegistry,
         sid: str = 'default',
         plugins: list[PluginRequirement] | None = None,
         env_vars: dict[str, str] | None = None,
@@ -64,6 +66,7 @@ class RemoteRuntime(ActionExecutionClient):
         super().__init__(
             config,
             event_stream,
+            llm_registry,
             sid,
             plugins,
             env_vars,
@@ -23,7 +23,7 @@ from openhands.events.observation import (
 )
 from openhands.linter import DefaultLinter
 from openhands.llm.llm import LLM
-from openhands.llm.metrics import Metrics
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches

 USER_MSG = """
@@ -128,7 +128,13 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
     # This restricts the number of lines we can edit to avoid exceeding the token limit.
     MAX_LINES_TO_EDIT = 300

-    def __init__(self, enable_llm_editor: bool, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self,
+        enable_llm_editor: bool,
+        llm_registry: LLMRegistry,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.enable_llm_editor = enable_llm_editor

@@ -138,7 +144,6 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
         draft_editor_config = self.config.get_llm_config('draft_editor')

         # manually set the model name for the draft editor LLM to distinguish token costs
-        llm_metrics = Metrics(model_name='draft_editor:' + draft_editor_config.model)
         if draft_editor_config.caching_prompt:
             logger.debug(
                 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. '
@@ -146,7 +151,9 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface):
             )
             draft_editor_config.caching_prompt = False

-        self.draft_editor_llm = LLM(draft_editor_config, metrics=llm_metrics)
+        self.draft_editor_llm = llm_registry.get_llm(
+            'draft_editor_llm', draft_editor_config
+        )
         logger.debug(
             f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}'
         )
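The mixin no longer builds its own Metrics bucket; keying the registry on 'draft_editor_llm' gives the same per-service cost separation. A sketch of the before/after equivalence (the commented-out lines are from the hunk above):

    # Before: a hand-built metrics bucket attached to a raw LLM.
    # llm_metrics = Metrics(model_name='draft_editor:' + draft_editor_config.model)
    # draft_editor_llm = LLM(draft_editor_config, metrics=llm_metrics)

    # After: the registry owns the metrics, keyed by service id.
    draft_editor_llm = llm_registry.get_llm('draft_editor_llm', draft_editor_config)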
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 import socketio

 from openhands.core.config import OpenHandsConfig
+from openhands.core.config.llm_config import LLMConfig
 from openhands.events.action import MessageAction
 from openhands.server.config.server_config import ServerConfig
 from openhands.server.data_models.agent_loop_info import AgentLoopInfo
@@ -136,6 +137,16 @@ class ConversationManager(ABC):
     ) -> list[AgentLoopInfo]:
         """Get the AgentLoopInfo for conversations."""

+    @abstractmethod
+    async def request_llm_completion(
+        self,
+        sid: str,
+        service_id: str,
+        llm_config: LLMConfig,
+        messages: list[dict[str, str]],
+    ) -> str:
+        """Request extraneous llm completions for a conversation"""
+
     @classmethod
     @abstractmethod
     def get_instance(
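'Extraneous' here means a one-off completion that is not part of the agent loop, e.g. title generation or prompt rewriting. A sketch of the expected call shape (argument values are illustrative):

    summary = await conversation_manager.request_llm_completion(
        sid=conversation_id,
        service_id='conversation_summary',  # illustrative service id
        llm_config=settings_llm_config,
        messages=[{'role': 'user', 'content': 'Summarize this conversation.'}],
    )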
@@ -15,13 +15,13 @@ from docker.models.containers import Container

 from openhands.controller.agent import Agent
 from openhands.core.config import OpenHandsConfig
+from openhands.core.config.llm_config import LLMConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.action import MessageAction
 from openhands.events.nested_event_store import NestedEventStore
 from openhands.events.stream import EventStream
 from openhands.experiments.experiment_manager import ExperimentManagerImpl
 from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
-from openhands.llm.llm import LLM
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
 from openhands.server.config.server_config import ServerConfig
@@ -42,6 +42,7 @@ from openhands.storage.files import FileStore
 from openhands.storage.locations import get_conversation_dir
 from openhands.utils.async_utils import call_sync_from_async
 from openhands.utils.import_utils import get_impl
+from openhands.utils.utils import create_registry_and_convo_stats


 @dataclass
@@ -275,6 +276,16 @@ class DockerNestedConversationManager(ConversationManager):
         # Not supported - clients should connect directly to the nested server!
         raise ValueError('unsupported_operation')

+    async def request_llm_completion(
+        self,
+        sid: str,
+        service_id: str,
+        llm_config: LLMConfig,
+        messages: list[dict[str, str]],
+    ) -> str:
+        # Not supported - clients should connect directly to the nested server!
+        raise ValueError('unsupported_operation')
+
     async def send_event_to_conversation(self, sid, data):
         async with httpx.AsyncClient(
             headers={
@@ -471,27 +482,27 @@ class DockerNestedConversationManager(ConversationManager):
         # This session is created here only because it is the easiest way to get a runtime, which
         # is the easiest way to create the needed docker container

-        # Run experiment manager variant test before creating session
         config: OpenHandsConfig = ExperimentManagerImpl.run_config_variant_test(
             user_id, sid, self.config
         )
+
+        llm_registry, convo_stats, config = create_registry_and_convo_stats(
+            config, sid, user_id, settings
+        )
+
         session = Session(
             sid=sid,
+            llm_registry=llm_registry,
+            convo_stats=convo_stats,
             file_store=self.file_store,
             config=config,
             sio=self.sio,
             user_id=user_id,
         )
+        llm_registry.retry_listner = session._notify_on_llm_retry
         agent_cls = settings.agent or config.default_agent
-        agent_name = agent_cls if agent_cls is not None else 'agent'
-        llm = LLM(
-            config=config.get_llm_config_from_agent(agent_name),
-            retry_listener=session._notify_on_llm_retry,
-        )
+        llm = session._create_llm(agent_cls)
         agent_config = config.get_agent_config(agent_cls)
-        agent = Agent.get_cls(agent_cls)(llm, agent_config)
+        agent = Agent.get_cls(agent_cls)(agent_config, llm_registry)

         config = config.model_copy(deep=True)
         env_vars = config.sandbox.runtime_startup_env_vars
@@ -543,6 +554,7 @@ class DockerNestedConversationManager(ConversationManager):
             headless_mode=False,
             attach_to_existing=False,
             main_module='openhands.server',
+            llm_registry=llm_registry,
         )

         # Hack - disable setting initial env.
@@ -6,12 +6,14 @@ from typing import Callable, Iterable

 import socketio

+from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.core.exceptions import AgentRuntimeUnavailableError
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.schema.agent import AgentState
 from openhands.events.action import MessageAction
 from openhands.events.stream import EventStreamSubscriber, session_exists
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime import get_runtime_cls
 from openhands.server.config.server_config import ServerConfig
 from openhands.server.constants import ROOM_KEY
@@ -37,6 +39,7 @@ from openhands.utils.conversation_summary import (
 )
 from openhands.utils.import_utils import get_impl
 from openhands.utils.shutdown_listener import should_continue
+from openhands.utils.utils import create_registry_and_convo_stats

 from .conversation_manager import ConversationManager

@@ -332,12 +335,15 @@ class StandaloneConversationManager(ConversationManager):
             )
             await self.close_session(oldest_conversation_id)

-        config = self.config.model_copy(deep=True)
+        llm_registry, convo_stats, config = create_registry_and_convo_stats(
+            self.config, sid, user_id, settings
+        )
         session = Session(
             sid=sid,
             file_store=self.file_store,
             config=config,
+            llm_registry=llm_registry,
+            convo_stats=convo_stats,
             sio=self.sio,
             user_id=user_id,
         )
@@ -349,7 +355,9 @@ class StandaloneConversationManager(ConversationManager):
         try:
             session.agent_session.event_stream.subscribe(
                 EventStreamSubscriber.SERVER,
-                self._create_conversation_update_callback(user_id, sid, settings),
+                self._create_conversation_update_callback(
+                    user_id, sid, settings, session.llm_registry
+                ),
                 UPDATED_AT_CALLBACK_ID,
             )
         except ValueError:
@@ -369,6 +377,21 @@ class StandaloneConversationManager(ConversationManager):
             raise RuntimeError(f'no_conversation:{sid}')
         await session.dispatch(data)

+    async def request_llm_completion(
+        self,
+        sid: str,
+        service_id: str,
+        llm_config: LLMConfig,
+        messages: list[dict[str, str]],
+    ):
+        session = self._local_agent_loops_by_sid.get(sid)
+        if not session:
+            raise RuntimeError(f'no_conversation:{sid}')
+        llm_registry = session.llm_registry
+        return llm_registry.request_extraneous_completion(
+            service_id, llm_config, messages
+        )
+
     async def disconnect_from_session(self, connection_id: str):
         sid = self._local_connection_id_to_session_id.pop(connection_id, None)
         logger.info(
@@ -450,6 +473,7 @@ class StandaloneConversationManager(ConversationManager):
         user_id: str | None,
         conversation_id: str,
         settings: Settings,
+        llm_registry: LLMRegistry,
     ) -> Callable:
         def callback(event, *args, **kwargs):
             call_async_from_sync(
@@ -458,6 +482,7 @@ class StandaloneConversationManager(ConversationManager):
                 user_id,
                 conversation_id,
                 settings,
+                llm_registry,
                 event,
             )
@@ -468,6 +493,7 @@ class StandaloneConversationManager(ConversationManager):
         user_id: str,
         conversation_id: str,
         settings: Settings,
+        llm_registry: LLMRegistry,
         event=None,
     ):
         conversation_store = await self._get_conversation_store(user_id)
@@ -495,7 +521,7 @@ class StandaloneConversationManager(ConversationManager):
                 conversation.title == default_title
             ):  # attempt to autogenerate if default title is in use
                 title = await auto_generate_title(
-                    conversation_id, user_id, self.file_store, settings
+                    conversation_id, user_id, self.file_store, settings, llm_registry
                 )
                 if title and not title.isspace():
                     conversation.title = title
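The standalone implementation above resolves the completion against the live session's registry, so completions are only available while the agent loop is running. Callers should be prepared for the failure path:

    try:
        title = await manager.request_llm_completion(
            sid,
            'title_generation',  # illustrative service id
            llm_config,
            [{'role': 'user', 'content': 'Suggest a short conversation title.'}],
        )
    except RuntimeError:
        # Raised as f'no_conversation:{sid}' when no local agent loop exists.
        title = None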
openhands/server/conversation_manager/utils.py (new file, 0 lines)
@@ -33,7 +33,6 @@ from openhands.integrations.service_types import (
     ProviderType,
     SuggestedTask,
 )
-from openhands.llm.llm import LLM
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.runtime_status import RuntimeStatus
 from openhands.server.data_models.agent_loop_info import AgentLoopInfo
@@ -47,6 +46,7 @@ from openhands.server.services.conversation_service import (
     setup_init_conversation_settings,
 )
 from openhands.server.shared import (
+    ConversationManagerImpl,
     ConversationStoreImpl,
     config,
     conversation_manager,
@@ -364,7 +364,7 @@ async def get_prompt(
     )

     prompt_template = generate_prompt_template(stringified_events)
-    prompt = generate_prompt(llm_config, prompt_template)
+    prompt = generate_prompt(llm_config, prompt_template, conversation_id)

     return JSONResponse(
         {
@@ -380,8 +380,9 @@ def generate_prompt_template(events: str) -> str:
     return template.render(events=events)


-def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
-    llm = LLM(llm_config)
+def generate_prompt(
+    llm_config: LLMConfig, prompt_template: str, conversation_id: str
+) -> str:
     messages = [
         {
             'role': 'system',
@@ -393,8 +394,9 @@ def generate_prompt(llm_config: LLMConfig, prompt_template: str) -> str:
         },
     ]

-    response = llm.completion(messages=messages)
-    raw_prompt = response['choices'][0]['message']['content'].strip()
+    raw_prompt = ConversationManagerImpl.request_llm_completion(
+        'remember_prompt', conversation_id, llm_config, messages
+    )
     prompt = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)

     if prompt:
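With this change the route never touches an LLM itself; the completion runs through the conversation's registry under the 'remember_prompt' service id, and the route only post-processes the text. A sketch of that extraction step (the fallback when no tag is found is an assumption; the hunk is cut before it):

    import re

    raw_prompt = '...<update_prompt>Remember to run the linter.</update_prompt>...'
    match = re.search(r'<update_prompt>(.*?)</update_prompt>', raw_prompt, re.DOTALL)
    prompt = match.group(1).strip() if match else raw_prompt.strip()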
@@ -31,20 +31,60 @@ from openhands.storage.data_models.user_secrets import UserSecrets
 from openhands.utils.conversation_summary import get_default_conversation_title


-async def create_new_conversation(
+async def initialize_conversation(
+    user_id: str | None,
+    conversation_id: str | None,
+    selected_repository: str | None,
+    selected_branch: str | None,
+    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
+    git_provider: ProviderType | None = None,
+) -> ConversationMetadata | None:
+    if conversation_id is None:
+        conversation_id = uuid.uuid4().hex
+
+    conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
+
+    if not await conversation_store.exists(conversation_id):
+        logger.info(
+            f'New conversation ID: {conversation_id}',
+            extra={'user_id': user_id, 'session_id': conversation_id},
+        )
+
+        conversation_title = get_default_conversation_title(conversation_id)
+
+        logger.info(f'Saving metadata for conversation {conversation_id}')
+        convo_metadata = ConversationMetadata(
+            trigger=conversation_trigger,
+            conversation_id=conversation_id,
+            title=conversation_title,
+            user_id=user_id,
+            selected_repository=selected_repository,
+            selected_branch=selected_branch,
+            git_provider=git_provider,
+        )
+
+        await conversation_store.save_metadata(convo_metadata)
+        return convo_metadata
+
+    try:
+        convo_metadata = await conversation_store.get_metadata(conversation_id)
+        return convo_metadata
+    except Exception:
+        pass
+
+    return None
+
+
+async def start_conversation(
     user_id: str | None,
     git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
     custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
-    selected_repository: str | None,
-    selected_branch: str | None,
     initial_user_msg: str | None,
     image_urls: list[str] | None,
     replay_json: str | None,
-    conversation_instructions: str | None = None,
-    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
-    attach_conversation_id: bool = False,
-    git_provider: ProviderType | None = None,
-    conversation_id: str | None = None,
+    conversation_id: str,
+    convo_metadata: ConversationMetadata,
+    conversation_instructions: str | None,
     mcp_config: MCPConfig | None = None,
 ) -> AgentLoopInfo:
     logger.info(
@@ -52,7 +92,7 @@ async def create_new_conversation(
         extra={
             'signal': 'create_conversation',
             'user_id': user_id,
-            'trigger': conversation_trigger.value,
+            'trigger': convo_metadata.trigger,
         },
     )
     logger.info('Loading settings')
@@ -79,53 +119,25 @@ async def create_new_conversation(
         raise MissingSettingsError('Settings not found')

     session_init_args['git_provider_tokens'] = git_provider_tokens
-    session_init_args['selected_repository'] = selected_repository
+    session_init_args['selected_repository'] = convo_metadata.selected_repository
     session_init_args['custom_secrets'] = custom_secrets
-    session_init_args['selected_branch'] = selected_branch
-    session_init_args['git_provider'] = git_provider
+    session_init_args['selected_branch'] = convo_metadata.selected_branch
+    session_init_args['git_provider'] = convo_metadata.git_provider
     session_init_args['conversation_instructions'] = conversation_instructions
     if mcp_config:
         session_init_args['mcp_config'] = mcp_config

     conversation_init_data = ConversationInitData(**session_init_args)

-    logger.info('Loading conversation store')
-    conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
-    logger.info('ServerConversation store loaded')
-
-    # For nested runtimes, we allow a single conversation id, passed in on container creation
-    if conversation_id is None:
-        conversation_id = uuid.uuid4().hex
-
-    if not await conversation_store.exists(conversation_id):
-        logger.info(
-            f'New conversation ID: {conversation_id}',
-            extra={'user_id': user_id, 'session_id': conversation_id},
-        )
-
     conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
         user_id, conversation_id, conversation_init_data
     )
-    conversation_title = get_default_conversation_title(conversation_id)
-
-    logger.info(f'Saving metadata for conversation {conversation_id}')
-    await conversation_store.save_metadata(
-        ConversationMetadata(
-            trigger=conversation_trigger,
-            conversation_id=conversation_id,
-            title=conversation_title,
-            user_id=user_id,
-            selected_repository=selected_repository,
-            selected_branch=selected_branch,
-            git_provider=git_provider,
-            llm_model=conversation_init_data.llm_model,
-        )
-    )
-
     logger.info(
         f'Starting agent loop for conversation {conversation_id}',
         extra={'user_id': user_id, 'session_id': conversation_id},
     )

     initial_message_action = None
     if initial_user_msg or image_urls:
         initial_message_action = MessageAction(
@@ -133,9 +145,6 @@ async def create_new_conversation(
             image_urls=image_urls or [],
         )

-    if attach_conversation_id:
-        logger.warning('Attaching conversation ID is deprecated, skipping process')
-
     agent_loop_info = await conversation_manager.maybe_start_agent_loop(
         conversation_id,
         conversation_init_data,
@@ -147,6 +156,47 @@ async def create_new_conversation(
     return agent_loop_info


+async def create_new_conversation(
+    user_id: str | None,
+    git_provider_tokens: PROVIDER_TOKEN_TYPE | None,
+    custom_secrets: CUSTOM_SECRETS_TYPE_WITH_JSON_SCHEMA | None,
+    selected_repository: str | None,
+    selected_branch: str | None,
+    initial_user_msg: str | None,
+    image_urls: list[str] | None,
+    replay_json: str | None,
+    conversation_instructions: str | None = None,
+    conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
+    git_provider: ProviderType | None = None,
+    conversation_id: str | None = None,
+    mcp_config: MCPConfig | None = None,
+) -> AgentLoopInfo:
+    conversation_metadata = await initialize_conversation(
+        user_id,
+        conversation_id,
+        selected_repository,
+        selected_branch,
+        conversation_trigger,
+        git_provider,
+    )
+
+    if not conversation_metadata:
+        raise Exception('Failed to initialize conversation')
+
+    return await start_conversation(
+        user_id,
+        git_provider_tokens,
+        custom_secrets,
+        initial_user_msg,
+        image_urls,
+        replay_json,
+        conversation_metadata.conversation_id,
+        conversation_metadata,
+        conversation_instructions,
+        mcp_config,
+    )
+
+
 def create_provider_tokens_object(
     providers_set: list[ProviderType],
 ) -> PROVIDER_TOKEN_TYPE:
|||||||
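The old monolithic entrypoint is split in two: initialize_conversation persists metadata, start_conversation boots the agent loop, and the new create_new_conversation just chains them. A hedged sketch of calling the new entrypoint follows; the import path is an assumption (the diff does not show the module name), but the signature matches the hunk above.

# Assumed module path; only the signature is taken from the diff.
from openhands.server.services.conversation_service import create_new_conversation

async def kick_off_demo() -> None:
    # Required parameters supplied explicitly; the optional trigger/provider/MCP
    # arguments keep the defaults from the signature above.
    agent_loop_info = await create_new_conversation(
        user_id='user-123',
        git_provider_tokens=None,
        custom_secrets=None,
        selected_repository=None,
        selected_branch=None,
        initial_user_msg='Fix the failing unit test',
        image_urls=None,
        replay_json=None,
    )
    print(agent_loop_info)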
openhands/server/services/conversation_stats.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import base64
+import pickle
+from threading import Lock
+
+from openhands.core.logger import openhands_logger as logger
+from openhands.llm.llm_registry import RegistryEvent
+from openhands.llm.metrics import Metrics
+from openhands.storage.files import FileStore
+from openhands.storage.locations import get_conversation_stats_filename
+
+
+class ConversationStats:
+    def __init__(
+        self,
+        file_store: FileStore | None,
+        conversation_id: str,
+        user_id: str | None,
+    ):
+        self.metrics_path = get_conversation_stats_filename(conversation_id, user_id)
+        self.file_store = file_store
+        self.conversation_id = conversation_id
+        self.user_id = user_id
+
+        self._save_lock = Lock()
+
+        self.service_to_metrics: dict[str, Metrics] = {}
+        self.restored_metrics: dict[str, Metrics] = {}
+
+        # Always attempt to restore registry if it exists
+        self.maybe_restore_metrics()
+
+    def save_metrics(self):
+        if not self.file_store:
+            return
+
+        with self._save_lock:
+            pickled = pickle.dumps(self.service_to_metrics)
+            serialized_metrics = base64.b64encode(pickled).decode('utf-8')
+            self.file_store.write(self.metrics_path, serialized_metrics)
+
+    def maybe_restore_metrics(self):
+        if not self.file_store or not self.conversation_id:
+            return
+
+        try:
+            encoded = self.file_store.read(self.metrics_path)
+            pickled = base64.b64decode(encoded)
+            self.restored_metrics = pickle.loads(pickled)
+            logger.info(f'restored metrics: {self.conversation_id}')
+        except FileNotFoundError:
+            pass
+
+    def get_combined_metrics(self) -> Metrics:
+        total_metrics = Metrics()
+        for metrics in self.service_to_metrics.values():
+            total_metrics.merge(metrics)
+
+        logger.info(f'metrics by all services: {self.service_to_metrics}')
+        logger.info(f'combined metrics\n\n{total_metrics}')
+        return total_metrics
+
+    def get_metrics_for_service(self, service_id: str) -> Metrics:
+        if service_id not in self.service_to_metrics:
+            raise Exception(f'LLM service does not exist {service_id}')
+
+        return self.service_to_metrics[service_id]
+
+    def register_llm(self, event: RegistryEvent):
+        # Listen for llm creations and track their metrics
+        llm = event.llm
+        service_id = event.service_id
+
+        if service_id in self.restored_metrics:
+            llm.metrics = self.restored_metrics[service_id].copy()
+            del self.restored_metrics[service_id]
+
+        self.service_to_metrics[service_id] = llm.metrics
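The persistence scheme above is just pickle-then-base64, so the binary payload can live in a text-oriented file store; pickle is acceptable here because the payload is only ever written and read by the server's own store. A self-contained round-trip of that encoding, with a plain dict standing in for the Metrics objects:

import base64
import pickle

# A plain dict stands in for the per-service Metrics objects.
service_to_metrics = {'agent': {'accumulated_cost': 0.42, 'prompt_tokens': 1234}}

# What save_metrics() writes to the file store:
serialized = base64.b64encode(pickle.dumps(service_to_metrics)).decode('utf-8')

# What maybe_restore_metrics() recovers on the next boot:
restored = pickle.loads(base64.b64decode(serialized))
assert restored == service_to_metrics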
@@ -21,6 +21,7 @@ from openhands.integrations.provider import (
     PROVIDER_TOKEN_TYPE,
     ProviderHandler,
 )
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.mcp import add_mcp_tools_to_agent
 from openhands.memory.memory import Memory
 from openhands.microagent.microagent import BaseMicroagent
@@ -29,6 +30,7 @@ from openhands.runtime.base import Runtime
 from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
 from openhands.runtime.runtime_status import RuntimeStatus
 from openhands.security import SecurityAnalyzer, options
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.data_models.user_secrets import UserSecrets
 from openhands.storage.files import FileStore
 from openhands.utils.async_utils import EXECUTOR, call_sync_from_async
@@ -48,6 +50,7 @@ class AgentSession:
     sid: str
     user_id: str | None
     event_stream: EventStream
+    llm_registry: LLMRegistry
     file_store: FileStore
     controller: AgentController | None = None
     runtime: Runtime | None = None
@@ -63,6 +66,8 @@ class AgentSession:
         self,
         sid: str,
         file_store: FileStore,
+        llm_registry: LLMRegistry,
+        convo_stats: ConversationStats,
         status_callback: Callable | None = None,
         user_id: str | None = None,
     ) -> None:
@@ -80,6 +85,8 @@ class AgentSession:
        self.logger = OpenHandsLoggerAdapter(
            extra={'session_id': sid, 'user_id': user_id}
        )
+        self.llm_registry = llm_registry
+        self.convo_stats = convo_stats

     async def start(
         self,
@@ -340,6 +347,7 @@ class AgentSession:
         self.runtime = runtime_cls(
             config=config,
             event_stream=self.event_stream,
+            llm_registry=self.llm_registry,
             sid=self.sid,
             plugins=agent.sandbox_plugins,
             status_callback=self._status_callback,
@@ -360,6 +368,7 @@ class AgentSession:
         self.runtime = runtime_cls(
             config=config,
             event_stream=self.event_stream,
+            llm_registry=self.llm_registry,
             sid=self.sid,
             plugins=agent.sandbox_plugins,
             status_callback=self._status_callback,
@@ -441,6 +450,7 @@ class AgentSession:
             user_id=self.user_id,
             file_store=self.file_store,
             event_stream=self.event_stream,
+            convo_stats=self.convo_stats,
             agent=agent,
             iteration_delta=int(max_iterations),
             budget_per_task_delta=max_budget_per_task,
@@ -490,6 +500,15 @@ class AgentSession:
         )
         return memory

+    def get_state(self) -> AgentState | None:
+        controller = self.controller
+        if controller:
+            return controller.state.agent_state
+        if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
+            # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
+            return AgentState.ERROR
+        return None
+
     def _maybe_restore_state(self) -> State | None:
         """Helper method to handle state restore logic."""
         restored_state = None
@@ -510,14 +529,5 @@ class AgentSession:
         self.logger.debug('No events found, no state to restore')
         return restored_state

-    def get_state(self) -> AgentState | None:
-        controller = self.controller
-        if controller:
-            return controller.state.agent_state
-        if time.time() > self._started_at + WAIT_TIME_BEFORE_CLOSE:
-            # If 5 minutes have elapsed and we still don't have a controller, something has gone wrong
-            return AgentState.ERROR
-        return None
-
     def is_closed(self) -> bool:
         return self._closed
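Putting the AgentSession changes together: the session now receives a registry and a stats tracker instead of building LLMs itself. A minimal wiring sketch, assuming a config and file store are already in hand; these are exactly the three objects that the create_registry_and_convo_stats helper, added in openhands/utils/utils.py later in this diff, bundles up.

from openhands.llm.llm_registry import LLMRegistry
from openhands.server.services.conversation_stats import ConversationStats
from openhands.server.session.agent_session import AgentSession

def build_session(config, file_store, sid: str) -> AgentSession:
    # One registry per conversation; every LLM it creates fires a RegistryEvent.
    llm_registry = LLMRegistry(config)
    convo_stats = ConversationStats(file_store, sid, user_id=None)
    # From here on, each new LLM's metrics are tracked under its service_id.
    llm_registry.subscribe(convo_stats.register_llm)

    return AgentSession(
        sid,
        file_store,
        llm_registry=llm_registry,
        convo_stats=convo_stats,
    )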
@@ -2,6 +2,7 @@ import asyncio

 from openhands.core.config import OpenHandsConfig
 from openhands.events.stream import EventStream
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.base import Runtime
 from openhands.security import SecurityAnalyzer, options
@@ -45,6 +46,7 @@ class ServerConversation:
         else:
             runtime_cls = get_runtime_cls(self.config.runtime)
             runtime = runtime_cls(
+                llm_registry=LLMRegistry(self.config),
                 config=config,
                 event_stream=self.event_stream,
                 sid=self.sid,
@@ -1,6 +1,5 @@
 import asyncio
 import time
-from copy import deepcopy
 from logging import LoggerAdapter

 import socketio
@@ -28,9 +27,10 @@ from openhands.events.observation.agent import RecallObservation
 from openhands.events.observation.error import ErrorObservation
 from openhands.events.serialization import event_from_dict, event_to_dict
 from openhands.events.stream import EventStreamSubscriber
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.runtime_status import RuntimeStatus
 from openhands.server.constants import ROOM_KEY
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.server.session.agent_session import AgentSession
 from openhands.server.session.conversation_init_data import ConversationInitData
 from openhands.storage.data_models.settings import Settings
@@ -45,6 +45,7 @@ class Session:
     agent_session: AgentSession
     loop: asyncio.AbstractEventLoop
     config: OpenHandsConfig
+    llm_registry: LLMRegistry
     file_store: FileStore
     user_id: str | None
     logger: LoggerAdapter
@@ -53,6 +54,8 @@ class Session:
         self,
         sid: str,
         config: OpenHandsConfig,
+        llm_registry: LLMRegistry,
+        convo_stats: ConversationStats,
         file_store: FileStore,
         sio: socketio.AsyncServer | None,
         user_id: str | None = None,
@@ -62,17 +65,21 @@ class Session:
         self.last_active_ts = int(time.time())
         self.file_store = file_store
         self.logger = OpenHandsLoggerAdapter(extra={'session_id': sid})
+        self.llm_registry = llm_registry
+        self.convo_stats = convo_stats
         self.agent_session = AgentSession(
             sid,
             file_store,
+            llm_registry=self.llm_registry,
+            convo_stats=convo_stats,
             status_callback=self.queue_status_message,
             user_id=user_id,
         )
         self.agent_session.event_stream.subscribe(
             EventStreamSubscriber.SERVER, self.on_event, self.sid
         )
-        # Copying this means that when we update variables they are not applied to the shared global configuration!
-        self.config = deepcopy(config)
+        self.config = config
         # Lazy import to avoid circular dependency
         from openhands.experiments.experiment_manager import ExperimentManagerImpl

@@ -140,13 +147,6 @@ class Session:
             else self.config.max_budget_per_task
         )

-        # This is a shallow copy of the default LLM config, so changes here will
-        # persist if we retrieve the default LLM config again when constructing
-        # the agent
-        default_llm_config = self.config.get_llm_config()
-        default_llm_config.model = settings.llm_model or ''
-        default_llm_config.api_key = settings.llm_api_key
-        default_llm_config.base_url = settings.llm_base_url
         self.config.search_api_key = settings.search_api_key
         if settings.sandbox_api_key:
             self.config.sandbox.api_key = settings.sandbox_api_key.get_secret_value()
@@ -181,10 +181,9 @@ class Session:
         )

         # TODO: override other LLM config & agent config groups (#2075)

-        llm = self._create_llm(agent_cls)
         agent_config = self.config.get_agent_config(agent_cls)
+        agent_name = agent_cls if agent_cls is not None else 'agent'
+        llm_config = self.config.get_llm_config_from_agent(agent_name)
         if settings.enable_default_condenser:
             # Default condenser chains three condensers together:
             # 1. a conversation window condenser that handles explicit
@@ -200,7 +199,7 @@ class Session:
                     ConversationWindowCondenserConfig(),
                     BrowserOutputCondenserConfig(attention_window=2),
                     LLMSummarizingCondenserConfig(
-                        llm_config=llm.config, keep_first=4, max_size=120
+                        llm_config=llm_config, keep_first=4, max_size=120
                    ),
                ]
            )
@@ -208,12 +207,14 @@ class Session:
             self.logger.info(
                 f'Enabling pipeline condenser with:'
                 f' browser_output_masking(attention_window=2), '
-                f' llm(model="{llm.config.model}", '
-                f' base_url="{llm.config.base_url}", '
+                f' llm(model="{llm_config.model}", '
+                f' base_url="{llm_config.base_url}", '
                 f' keep_first=4, max_size=80)'
             )
             agent_config.condenser = default_condenser_config
-        agent = Agent.get_cls(agent_cls)(llm, agent_config)
+        agent = Agent.get_cls(agent_cls)(agent_config, self.llm_registry)
+
+        self.llm_registry.retry_listner = self._notify_on_llm_retry

         git_provider_tokens = None
         selected_repository = None
@@ -269,14 +270,6 @@ class Session:
         )
         return

-    def _create_llm(self, agent_cls: str | None) -> LLM:
-        """Initialize LLM, extracted for testing."""
-        agent_name = agent_cls if agent_cls is not None else 'agent'
-        return LLM(
-            config=self.config.get_llm_config_from_agent(agent_name),
-            retry_listener=self._notify_on_llm_retry,
-        )
-
     def _notify_on_llm_retry(self, retries: int, max: int) -> None:
         self.queue_status_message(
             'info', RuntimeStatus.LLM_RETRY, f'Retrying LLM request, {retries} / {max}'
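The removed _create_llm helper attached a retry listener to each LLM it built; the registry-level hook replaces it, so one callback covers every LLM the registry hands out. A sketch of the equivalent wiring outside the class, assuming an LLMRegistry instance as constructed earlier (note the attribute is spelled retry_listner in the diff, and the sketch follows it verbatim):

def notify_on_llm_retry(retries: int, max_retries: int) -> None:
    # Stand-in for Session._notify_on_llm_retry, which queues a status message.
    print(f'Retrying LLM request, {retries} / {max_retries}')

# Attribute name copied verbatim from the diff above, including its spelling.
llm_registry.retry_listner = notify_on_llm_retry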
@@ -30,5 +30,13 @@ def get_conversation_agent_state_filename(sid: str, user_id: str | None = None)
     return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'


+def get_conversation_llm_registry_filename(sid: str, user_id: str | None = None) -> str:
+    return f'{get_conversation_dir(sid, user_id)}llm_registry.json'
+
+
+def get_conversation_stats_filename(sid: str, user_id: str | None = None) -> str:
+    return f'{get_conversation_dir(sid, user_id)}convo_stats.json'
+
+
 def get_experiment_config_filename(sid: str, user_id: str | None = None) -> str:
     return f'{get_conversation_dir(sid, user_id)}exp_config.json'
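The two new helpers keep the stats and registry artifacts next to the other per-conversation files. An illustrative call, where the printed prefixes are assumptions (the exact layout comes from get_conversation_dir and the file store configuration):

from openhands.storage.locations import (
    get_conversation_llm_registry_filename,
    get_conversation_stats_filename,
)

print(get_conversation_stats_filename('abc123'))
# e.g. sessions/abc123/convo_stats.json  (prefix is an assumption)
print(get_conversation_llm_registry_filename('abc123', 'user-1'))
# e.g. users/user-1/conversations/abc123/llm_registry.json  (assumed layout)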
@@ -7,13 +7,16 @@ from openhands.core.logger import openhands_logger as logger
 from openhands.events.action.message import MessageAction
 from openhands.events.event import EventSource
 from openhands.events.event_store import EventStore
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.storage.data_models.settings import Settings
 from openhands.storage.files import FileStore


 async def generate_conversation_title(
-    message: str, llm_config: LLMConfig, max_length: int = 50
+    message: str,
+    llm_config: LLMConfig,
+    llm_registry: LLMRegistry,
+    max_length: int = 50,
 ) -> Optional[str]:
     """Generate a concise title for a conversation based on the first user message.

@@ -35,8 +38,6 @@ async def generate_conversation_title(
         truncated_message = message

     try:
-        llm = LLM(llm_config)
-
         # Create a simple prompt for the LLM to generate a title
         messages = [
             {
@@ -49,8 +50,9 @@ async def generate_conversation_title(
             },
         ]

-        response = llm.completion(messages=messages)
-        title = response.choices[0].message.content.strip()
+        title = llm_registry.request_extraneous_completion(
+            'convo_title_creator', llm_config, messages
+        )

         # Ensure the title isn't too long
         if len(title) > max_length:
@@ -75,7 +77,11 @@ def get_default_conversation_title(conversation_id: str) -> str:


 async def auto_generate_title(
-    conversation_id: str, user_id: str | None, file_store: FileStore, settings: Settings
+    conversation_id: str,
+    user_id: str | None,
+    file_store: FileStore,
+    settings: Settings,
+    llm_registry: LLMRegistry,
 ) -> str:
     """Auto-generate a title for a conversation based on the first user message.
     Uses LLM-based title generation if available, otherwise falls back to a simple truncation.
@@ -116,7 +122,7 @@ async def auto_generate_title(

     # Try to generate title using LLM
     llm_title = await generate_conversation_title(
-        first_user_message, llm_config
+        first_user_message, llm_config, llm_registry
     )
     if llm_title:
         logger.info(f'Generated title using LLM: {llm_title}')
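Instead of constructing a throwaway LLM, title generation now asks the registry for a one-off ("extraneous") completion under the dedicated 'convo_title_creator' service id, so even this side request lands in conversation stats. A hedged sketch of the call, assuming an llm_registry and llm_config are already in scope; the signature and the fact that it returns the title text directly are taken from the hunk above:

messages = [
    {
        'role': 'system',
        'content': 'Generate a concise title for this conversation.',
    },
    {'role': 'user', 'content': 'Please fix the flaky integration test'},
]

# Returns a string, not a raw completion response.
title = llm_registry.request_extraneous_completion(
    'convo_title_creator', llm_config, messages
)
print(title)  # callers still enforce max_length themselves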
openhands/utils/utils.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from copy import deepcopy
+
+from openhands.core.config.openhands_config import OpenHandsConfig
+from openhands.llm.llm_registry import LLMRegistry
+from openhands.server.services.conversation_stats import ConversationStats
+from openhands.storage import get_file_store
+from openhands.storage.data_models.settings import Settings
+
+
+def setup_llm_config(config: OpenHandsConfig, settings: Settings) -> OpenHandsConfig:
+    # Copying this means that when we update variables they are not applied to the shared global configuration!
+    config = deepcopy(config)
+
+    llm_config = config.get_llm_config()
+    llm_config.model = settings.llm_model or ''
+    llm_config.api_key = settings.llm_api_key
+    llm_config.base_url = settings.llm_base_url
+    config.set_llm_config(llm_config)
+    return config
+
+
+def create_registry_and_convo_stats(
+    config: OpenHandsConfig,
+    sid: str,
+    user_id: str | None,
+    user_settings: Settings | None = None,
+) -> tuple[LLMRegistry, ConversationStats, OpenHandsConfig]:
+    user_config = config
+    if user_settings:
+        user_config = setup_llm_config(config, user_settings)
+
+    agent_cls = user_settings.agent if user_settings else None
+    llm_registry = LLMRegistry(user_config, agent_cls)
+    file_store = get_file_store(user_config.file_store, user_config.file_store_path)
+    convo_stats = ConversationStats(file_store, sid, user_id)
+    llm_registry.subscribe(convo_stats.register_llm)
+    return llm_registry, convo_stats, user_config
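A usage sketch for the new helper. It performs the same wiring shown manually earlier, and setup_llm_config deep-copies the config first, which is how the deepcopy removed from Session.__init__ survives the refactor:

from openhands.core.config import load_openhands_config
from openhands.utils.utils import create_registry_and_convo_stats

config = load_openhands_config()
llm_registry, convo_stats, user_config = create_registry_and_convo_stats(
    config, sid='abc123', user_id=None
)
# user_config is the (possibly copied) config; when user settings are supplied,
# the shared global configuration is left untouched.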
@@ -1,3 +1,4 @@
 [pytest]
 addopts = -p no:warnings
+asyncio_mode = auto
 asyncio_default_fixture_loop_scope = function
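With asyncio_mode = auto, pytest-asyncio collects and runs coroutine tests without requiring an explicit marker on each one. For example:

# No @pytest.mark.asyncio decorator needed in auto mode.
async def test_example():
    assert 1 + 1 == 2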
tests/__init__.py (new file, 0 lines)
@@ -10,6 +10,7 @@ from pytest import TempPathFactory
 from openhands.core.config import MCPConfig, OpenHandsConfig, load_openhands_config
 from openhands.core.logger import openhands_logger as logger
 from openhands.events import EventStream
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.base import Runtime
 from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
 from openhands.runtime.impl.docker.docker_runtime import DockerRuntime
@@ -268,9 +269,13 @@ def _load_runtime(
     )
     event_stream = EventStream(sid, file_store)

+    # Create a LLMRegistry instance for the runtime
+    llm_registry = LLMRegistry(config=OpenHandsConfig())
+
     runtime = runtime_cls(
         config=config,
         event_stream=event_stream,
+        llm_registry=llm_registry,
         sid=sid,
         plugins=plugins,
     )
tests/unit/__init__.py (new file, 0 lines)
@@ -20,7 +20,7 @@ def test_llm():

 def _get_llm(type_: type[LLM]):
     with _patch_http():
-        return type_(config=config.get_llm_config())
+        return type_(config=config.get_llm_config(), service_id='test_service')


 @pytest.fixture
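Every direct LLM construction in the tests now passes a service_id, since the registry keys metrics by service. A minimal construction mirroring the fixtures:

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

# service_id is the key ConversationStats uses to bucket this LLM's metrics.
llm = LLM(LLMConfig(model='gpt-4o', api_key='test_key'), service_id='test-service')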
@@ -38,7 +38,7 @@ def default_config():


 def test_llm_init_with_default_config(default_config):
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     assert llm.config.model == 'gpt-4o'
     assert llm.config.api_key.get_secret_value() == 'test_key'
     assert isinstance(llm.metrics, Metrics)
@@ -129,7 +129,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
         'max_input_tokens': 8000,
         'max_output_tokens': 2000,
     }
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     llm.init_model_info()
     assert llm.config.max_input_tokens == 8000
     assert llm.config.max_output_tokens == 2000
@@ -138,7 +138,7 @@ def test_llm_init_with_model_info(mock_get_model_info, default_config):
 @patch('openhands.llm.llm.litellm.get_model_info')
 def test_llm_init_without_model_info(mock_get_model_info, default_config):
     mock_get_model_info.side_effect = Exception('Model info not available')
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     llm.init_model_info()
     assert llm.config.max_input_tokens is None
     assert llm.config.max_output_tokens is None
@@ -154,7 +154,7 @@ def test_llm_init_with_custom_config():
         top_p=0.9,
         top_k=None,
     )
-    llm = LLM(custom_config)
+    llm = LLM(custom_config, service_id='test-service')
     assert llm.config.model == 'custom-model'
     assert llm.config.api_key.get_secret_value() == 'custom_key'
     assert llm.config.max_input_tokens == 5000
@@ -168,7 +168,7 @@ def test_llm_init_with_custom_config():
 def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
     # Create a config with top_k set
     config_with_top_k = LLMConfig(top_k=50)
-    llm = LLM(config_with_top_k)
+    llm = LLM(config_with_top_k, service_id='test-service')

     # Define a side effect function to check top_k
     def side_effect(*args, **kwargs):
@@ -186,7 +186,7 @@ def test_llm_top_k_in_completion_when_set(mock_litellm_completion):
 def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
     # Create a config with top_k set to None
     config_without_top_k = LLMConfig(top_k=None)
-    llm = LLM(config_without_top_k)
+    llm = LLM(config_without_top_k, service_id='test-service')

     # Define a side effect function to check top_k
     def side_effect(*args, **kwargs):
@@ -202,7 +202,7 @@ def test_llm_top_k_not_in_completion_when_none(mock_litellm_completion):
 def test_llm_init_with_metrics():
     config = LLMConfig(model='gpt-4o', api_key='test_key')
     metrics = Metrics()
-    llm = LLM(config, metrics=metrics)
+    llm = LLM(config, metrics=metrics, service_id='test-service')
     assert llm.metrics is metrics
     assert (
         llm.metrics.model_name == 'default'
@@ -224,7 +224,7 @@ def test_response_latency_tracking(mock_time, mock_litellm_completion):

     # Create LLM instance and make a completion call
     config = LLMConfig(model='gpt-4o', api_key='test_key')
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')
     response = llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])

     # Verify the response latency was tracked correctly
@@ -257,7 +257,7 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config):
         'max_input_tokens': 7000,
         'max_output_tokens': 1500,
     }
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     llm.init_model_info()
     assert llm.config.max_input_tokens == 7000
     assert llm.config.max_output_tokens == 1500
@@ -280,7 +280,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):
     default_config.model = (
         'custom-model'  # Use a model not in FUNCTION_CALLING_SUPPORTED_MODELS
     )
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         tools=[
@@ -292,7 +292,7 @@ def test_stop_parameter_handling(mock_litellm_completion, default_config):

     # Test with Grok-4 model that doesn't support stop parameter
     default_config.model = 'xai/grok-4-0709'
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         tools=[
@@ -314,7 +314,7 @@ def test_completion_with_mocked_logger(
         'choices': [{'message': {'content': 'Test response'}}]
     }

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -345,7 +345,7 @@ def test_completion_retries(
         {'choices': [{'message': {'content': 'Retry successful'}}]},
     ]

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -365,7 +365,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
         {'choices': [{'message': {'content': 'Retry successful'}}]},
     ]

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -387,7 +387,7 @@ def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config
 def test_completion_operation_cancelled(mock_litellm_completion, default_config):
     mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     with pytest.raises(OperationCancelled):
         llm.completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
@@ -404,7 +404,7 @@ def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):

     mock_litellm_completion.side_effect = side_effect

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     with pytest.raises(OperationCancelled):
         try:
             llm.completion(
@@ -428,7 +428,7 @@ def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_

     mock_litellm_completion.side_effect = side_effect

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     result = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -469,7 +469,7 @@ def test_completion_retry_with_llm_no_response_error_zero_temp(
     mock_litellm_completion.side_effect = side_effect

     # Create LLM instance and make a completion call
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -509,7 +509,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp(
         'LLM did not return a response'
     )

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     with pytest.raises(LLMNoResponseError):
         llm.completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
@@ -575,7 +575,7 @@ def test_gemini_25_pro_function_calling(mock_httpx_get, mock_get_model_info):

     for model_name, expected_support in test_cases:
         config = LLMConfig(model=model_name, api_key='test_key')
-        llm = LLM(config)
+        llm = LLM(config, service_id='test-service')

         assert llm.is_function_calling_active() == expected_support, (
             f'Expected function calling support to be {expected_support} for model {model_name}'
@@ -617,7 +617,7 @@ def test_completion_retry_with_llm_no_response_error_nonzero_temp_successful_ret
     mock_litellm_completion.side_effect = side_effect

     # Create LLM instance and make a completion call with non-zero temperature
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -677,7 +677,7 @@ def test_completion_retry_with_llm_no_response_error_successful_retry(
     mock_litellm_completion.side_effect = side_effect

     # Create LLM instance and make a completion call with explicit temperature=0
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -709,7 +709,7 @@ def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
     }
     mock_litellm_completion.return_value = mock_response

-    test_llm = LLM(config=default_config)
+    test_llm = LLM(config=default_config, service_id='test-service')
     response = test_llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -743,7 +743,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
     }

     # Initialize LLM and call completion
-    llm = LLM(config=gemini_config)
+    llm = LLM(config=gemini_config, service_id='test-service')
     llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])

     # Verify that litellm_completion was called with the 'thinking' parameter
@@ -762,7 +762,7 @@ def test_llm_gemini_thinking_parameter(mock_litellm_completion, default_config):
 @patch('openhands.llm.llm.litellm.token_counter')
 def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
     mock_token_counter.return_value = 42
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     messages = [{'role': 'user', 'content': 'Hello!'}]

     token_count = llm.get_token_count(messages)
@@ -777,7 +777,7 @@ def test_get_token_count_with_dict_messages(mock_token_counter, default_config):
 def test_get_token_count_with_message_objects(
     mock_token_counter, default_config, mock_logger
 ):
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')

     # Create a Message object and its equivalent dict
     message_obj = Message(role='user', content=[TextContent(text='Hello!')])
@@ -806,7 +806,7 @@ def test_get_token_count_with_custom_tokenizer(

     config = copy.deepcopy(default_config)
     config.custom_tokenizer = 'custom/tokenizer'
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')
     messages = [{'role': 'user', 'content': 'Hello!'}]

     token_count = llm.get_token_count(messages)
@@ -823,7 +823,7 @@ def test_get_token_count_error_handling(
     mock_token_counter, default_config, mock_logger
 ):
     mock_token_counter.side_effect = Exception('Token counting failed')
-    llm = LLM(default_config)
+    llm = LLM(default_config, service_id='test-service')
     messages = [{'role': 'user', 'content': 'Hello!'}]

     token_count = llm.get_token_count(messages)
@@ -865,7 +865,7 @@ def test_llm_token_usage(mock_litellm_completion, default_config):
     # We'll make mock_litellm_completion return these responses in sequence
     mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')

     # First call
     llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])
@@ -924,7 +924,7 @@ def test_accumulated_token_usage(mock_litellm_completion, default_config):
     mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]

     # Create LLM instance
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')

     # First call
     llm.completion(messages=[{'role': 'user', 'content': 'First message'}])
@@ -980,7 +980,7 @@ def test_completion_with_log_completions(mock_litellm_completion, default_config
     }
     mock_litellm_completion.return_value = mock_response

-    test_llm = LLM(config=default_config)
+    test_llm = LLM(config=default_config, service_id='test-service')
     response = test_llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -1006,7 +1006,7 @@ def test_llm_base_url_auto_protocol_patch(mock_get):
     mock_get.return_value.status_code = 200
     mock_get.return_value.json.return_value = {'model': 'fake'}

-    llm = LLM(config=config)
+    llm = LLM(config=config, service_id='test-service')
     llm.init_model_info()

     called_url = mock_get.call_args[0][0]
@@ -1020,7 +1020,7 @@ def test_unknown_model_token_limits():
     """Test that models without known token limits get None for both max_output_tokens and max_input_tokens."""
     # Create LLM instance with a non-existent model to avoid litellm having model info for it
     config = LLMConfig(model='non-existent-model', api_key='test_key')
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')

     # Verify max_output_tokens and max_input_tokens are initialized to None (default value)
     assert llm.config.max_output_tokens is None
@@ -1031,7 +1031,7 @@ def test_max_tokens_from_model_info():
     """Test that max_output_tokens and max_input_tokens are correctly initialized from model info."""
     # Create LLM instance with GPT-4 model which has known token limits
     config = LLMConfig(model='gpt-4', api_key='test_key')
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')

     # GPT-4 has specific token limits
     # These are the expected values from litellm
@@ -1043,7 +1043,7 @@ def test_claude_3_7_sonnet_max_output_tokens():
     """Test that Claude 3.7 Sonnet models get the special 64000 max_output_tokens value and default max_input_tokens."""
     # Create LLM instance with Claude 3.7 Sonnet model
     config = LLMConfig(model='claude-3-7-sonnet', api_key='test_key')
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')

     # Verify max_output_tokens is set to 64000 for Claude 3.7 Sonnet
     assert llm.config.max_output_tokens == 64000
@@ -1055,7 +1055,7 @@ def test_claude_sonnet_4_max_output_tokens():
     """Test that Claude Sonnet 4 models get the correct max_output_tokens and max_input_tokens values."""
     # Create LLM instance with a Claude Sonnet 4 model
     config = LLMConfig(model='claude-sonnet-4-20250514', api_key='test_key')
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')

     # Verify max_output_tokens is set to the expected value
     assert llm.config.max_output_tokens == 64000
@@ -1068,7 +1068,7 @@ def test_sambanova_deepseek_model_max_output_tokens():
     """Test that SambaNova DeepSeek-V3-0324 model gets the correct max_output_tokens value."""
     # Create LLM instance with SambaNova DeepSeek model
     config = LLMConfig(model='sambanova/DeepSeek-V3-0324', api_key='test_key')
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')

     # SambaNova DeepSeek model has specific token limits
     # This is the expected value from litellm
@@ -1081,7 +1081,7 @@ def test_max_output_tokens_override_in_config():
     config = LLMConfig(
         model='claude-sonnet-4-20250514', api_key='test_key', max_output_tokens=2048
     )
-    llm = LLM(config)
+    llm = LLM(config, service_id='test-service')

     # Verify the config has the overridden max_output_tokens value
     assert llm.config.max_output_tokens == 2048
@@ -1098,7 +1098,7 @@ def test_azure_model_default_max_tokens():
     )

     # Create LLM instance with Azure model
-    llm = LLM(azure_config)
+    llm = LLM(azure_config, service_id='test-service')

     # Verify the config has the default max_output_tokens value
     assert llm.config.max_output_tokens is None  # Default value
@@ -1143,7 +1143,7 @@ def test_gemini_none_reasoning_effort_uses_thinking_budget(mock_completion):
|
|||||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||||
llm.completion(messages=sample_messages)
|
llm.completion(messages=sample_messages)
|
||||||
|
|
||||||
@@ -1167,7 +1167,7 @@ def test_gemini_low_reasoning_effort_uses_thinking_budget(mock_completion):
|
|||||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||||
llm.completion(messages=sample_messages)
|
llm.completion(messages=sample_messages)
|
||||||
|
|
||||||
@@ -1191,7 +1191,7 @@ def test_gemini_medium_reasoning_effort_passes_through(mock_completion):
|
|||||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||||
llm.completion(messages=sample_messages)
|
llm.completion(messages=sample_messages)
|
||||||
|
|
||||||
@@ -1214,7 +1214,7 @@ def test_gemini_high_reasoning_effort_passes_through(mock_completion):
|
|||||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||||
llm.completion(messages=sample_messages)
|
llm.completion(messages=sample_messages)
|
||||||
|
|
||||||
@@ -1235,7 +1235,7 @@ def test_non_gemini_uses_reasoning_effort(mock_completion):
|
|||||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||||
llm.completion(messages=sample_messages)
|
llm.completion(messages=sample_messages)
|
||||||
|
|
||||||
@@ -1259,7 +1259,7 @@ def test_non_reasoning_model_no_optimization(mock_completion):
|
|||||||
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
'usage': {'prompt_tokens': 10, 'completion_tokens': 5},
|
||||||
}
|
}
|
||||||
|
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
sample_messages = [{'role': 'user', 'content': 'Hello, how are you?'}]
|
||||||
llm.completion(messages=sample_messages)
|
llm.completion(messages=sample_messages)
|
||||||
|
|
||||||
@@ -1285,7 +1285,7 @@ def test_gemini_performance_optimization_end_to_end(mock_completion):
|
|||||||
assert config.reasoning_effort is None
|
assert config.reasoning_effort is None
|
||||||
|
|
||||||
# Create LLM and make completion
|
# Create LLM and make completion
|
||||||
llm = LLM(config)
|
llm = LLM(config, service_id='test-service')
|
||||||
messages = [{'role': 'user', 'content': 'Solve this complex problem'}]
|
messages = [{'role': 'user', 'content': 'Solve this complex problem'}]
|
||||||
|
|
||||||
response = llm.completion(messages=messages)
|
response = llm.completion(messages=messages)
|
||||||
|
|||||||
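Every bare LLM(config) construction in the hunks above gains an explicit service_id, which labels the instance so its token and cost metrics can be attributed to one service. A minimal sketch of the new pattern, using only the constructor shape the diff shows (the model and id values here are placeholders, not API requirements):

    from openhands.core.config.llm_config import LLMConfig
    from openhands.llm.llm import LLM

    config = LLMConfig(model='gpt-4o', api_key='test_key')
    # service_id tags this instance for per-service metrics accounting
    llm = LLM(config, service_id='my-service')

The resolver hunks that follow apply the same one-line change to the GitHub handler tests, and then to their GitLab counterparts.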
@@ -207,7 +207,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
         ),
     ]

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     handler = ServiceContextIssue(
         GithubIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
     )
@@ -251,7 +251,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
     )

     # Initialize LLM and handler
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     handler = ServiceContextPR(
         GithubPRHandler('test-owner', 'test-repo', 'test-token'), default_config
     )
@@ -463,7 +463,7 @@ async def test_process_issue(
         [],
     )
     handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
-    handler_instance.llm = LLM(llm_config)
+    handler_instance.llm = LLM(llm_config, service_id='test-service')

     # Mock the runtime and its methods
     mock_runtime = MagicMock()
@@ -209,7 +209,7 @@ def test_guess_success_rate_limit_wait_time(mock_litellm_completion, default_con
         ),
     ]

-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     handler = ServiceContextIssue(
         GitlabIssueHandler('test-owner', 'test-repo', 'test-token'), default_config
     )
@@ -253,7 +253,7 @@ def test_guess_success_exhausts_retries(mock_completion, default_config):
     )

     # Initialize LLM and handler
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     handler = ServiceContextPR(
         GitlabPRHandler('test-owner', 'test-repo', 'test-token'), default_config
     )
@@ -500,7 +500,7 @@ async def test_process_issue(
         [],
     )
     handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
-    handler_instance.llm = LLM(llm_config)
+    handler_instance.llm = LLM(llm_config, service_id='test-service')

     # Create mock runtime and mock run_controller
     mock_runtime = MagicMock()
@@ -18,6 +18,7 @@ from openhands.controller.state.control_flags import (
 from openhands.controller.state.state import State
 from openhands.core.config import OpenHandsConfig
 from openhands.core.config.agent_config import AgentConfig
+from openhands.core.config.llm_config import LLMConfig
 from openhands.core.main import run_controller
 from openhands.core.schema import AgentState
 from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
@@ -33,6 +34,7 @@ from openhands.events.observation.agent import RecallObservation
 from openhands.events.observation.empty import NullObservation
 from openhands.events.serialization import event_to_dict
 from openhands.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
 from openhands.llm.metrics import Metrics, TokenUsage
 from openhands.memory.condenser.condenser import Condensation
 from openhands.memory.condenser.impl.conversation_window_condenser import (
@@ -45,6 +47,7 @@ from openhands.runtime.impl.action_execution.action_execution_client import (
     ActionExecutionClient,
 )
 from openhands.runtime.runtime_status import RuntimeStatus
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.storage.memory import InMemoryFileStore


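These three added imports carry the refactor's core wiring: an LLMRegistry hands out LLM instances keyed by a service id, and a ConversationStats object subscribes to the registry so every LLM it creates reports metrics into one per-conversation store. A condensed sketch of that wiring, assembled from calls that appear verbatim in the fixture below (the conversation id is an arbitrary placeholder):

    import uuid

    from openhands.core.config import OpenHandsConfig
    from openhands.llm.llm_registry import LLMRegistry
    from openhands.server.services.conversation_stats import ConversationStats
    from openhands.storage.memory import InMemoryFileStore

    config = OpenHandsConfig()
    llm_registry = LLMRegistry(config=config)

    conversation_stats = ConversationStats(
        file_store=InMemoryFileStore({}),
        conversation_id=f'test-conversation-{uuid.uuid4()}',
        user_id='test-user',
    )

    # From here on, every LLM the registry creates is reported to the stats object
    llm_registry.subscribe(conversation_stats.register_llm)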
@@ -61,15 +64,43 @@ def event_loop():


 @pytest.fixture
-def mock_agent():
-    agent = MagicMock(spec=Agent)
-    agent.llm = MagicMock(spec=LLM)
-    agent.llm.metrics = Metrics()
-    agent.llm.config = OpenHandsConfig().get_llm_config()
+def mock_agent_with_stats():
+    """Create a mock agent with properly connected LLM registry and conversation stats."""
+    import uuid

-    # Add config with enable_mcp attribute
-    agent.config = MagicMock(spec=AgentConfig)
-    agent.config.enable_mcp = True
+    # Create LLM registry
+    config = OpenHandsConfig()
+    llm_registry = LLMRegistry(config=config)
+
+    # Create conversation stats
+    file_store = InMemoryFileStore({})
+    conversation_id = f'test-conversation-{uuid.uuid4()}'
+    conversation_stats = ConversationStats(
+        file_store=file_store, conversation_id=conversation_id, user_id='test-user'
+    )
+
+    # Connect registry to stats (this is the key requirement)
+    llm_registry.subscribe(conversation_stats.register_llm)
+
+    # Create mock agent
+    agent = MagicMock(spec=Agent)
+    agent_config = MagicMock(spec=AgentConfig)
+    llm_config = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+    agent_config.disabled_microagents = []
+    agent_config.enable_mcp = True
+    llm_registry.service_to_llm.clear()
+    mock_llm = llm_registry.get_llm('agent_llm', llm_config)
+    agent.llm = mock_llm
+    agent.name = 'test-agent'
+    agent.sandbox_plugins = []
+    agent.config = agent_config
+    agent.prompt_manager = MagicMock()

     # Add a proper system message mock
     system_message = SystemMessageAction(
@@ -79,7 +110,7 @@ def mock_agent():
     system_message._id = -1  # Set invalid ID to avoid the ID check
     agent.get_system_message.return_value = system_message

-    return agent
+    return agent, conversation_stats, llm_registry


 @pytest.fixture
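Note the shape change: the fixture no longer fabricates a MagicMock LLM but requests a real one through llm_registry.get_llm('agent_llm', llm_config), so registration with the stats object happens as a side effect, and the fixture now returns a three-tuple. Every consuming test therefore opens with the same unpacking line, which recurs through the hunks below:

    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats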
@@ -134,10 +165,13 @@ async def send_event_to_controller(controller, event):


 @pytest.mark.asyncio
-async def test_set_agent_state(mock_agent, mock_event_stream):
+async def test_set_agent_state(mock_agent_with_stats, mock_event_stream):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -152,10 +186,13 @@ async def test_set_agent_state(mock_agent, mock_event_stream):


 @pytest.mark.asyncio
-async def test_on_event_message_action(mock_agent, mock_event_stream):
+async def test_on_event_message_action(mock_agent_with_stats, mock_event_stream):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -169,10 +206,15 @@ async def test_on_event_message_action(mock_agent, mock_event_stream):


 @pytest.mark.asyncio
-async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
+async def test_on_event_change_agent_state_action(
+    mock_agent_with_stats, mock_event_stream
+):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -186,10 +228,17 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream)


 @pytest.mark.asyncio
-async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback):
+async def test_react_to_exception(
+    mock_agent_with_stats,
+    mock_event_stream,
+    mock_status_callback,
+):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         status_callback=mock_status_callback,
         iteration_delta=10,
         sid='test',
@@ -204,12 +253,17 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal

 @pytest.mark.asyncio
 async def test_react_to_content_policy_violation(
-    mock_agent, mock_event_stream, mock_status_callback
+    mock_agent_with_stats,
+    mock_event_stream,
+    mock_status_callback,
 ):
     """Test that the controller properly handles content policy violations from the LLM."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         status_callback=mock_status_callback,
         iteration_delta=10,
         sid='test',
@@ -246,18 +300,16 @@ async def test_react_to_content_policy_violation(

 @pytest.mark.asyncio
 async def test_run_controller_with_fatal_error(
-    test_event_stream, mock_memory, mock_agent
+    test_event_stream, mock_memory, mock_agent_with_stats
 ):
     config = OpenHandsConfig()
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats

     def agent_step_fn(state):
         print(f'agent_step_fn received state: {state}')
         return CmdRunAction(command='ls')

     mock_agent.step = agent_step_fn
-    mock_agent.llm = MagicMock(spec=LLM)
-    mock_agent.llm.metrics = Metrics()
-    mock_agent.llm.config = config.get_llm_config()

     runtime = MagicMock(spec=ActionExecutionClient)

@@ -284,15 +336,17 @@ async def test_run_controller_with_fatal_error(
         EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
     )

-    state = await run_controller(
-        config=config,
-        initial_user_action=MessageAction(content='Test message'),
-        runtime=runtime,
-        sid='test',
-        agent=mock_agent,
-        fake_user_response_fn=lambda _: 'repeat',
-        memory=mock_memory,
-    )
+    # Mock the create_agent function to return our mock agent
+    with patch('openhands.core.main.create_agent', return_value=mock_agent):
+        state = await run_controller(
+            config=config,
+            initial_user_action=MessageAction(content='Test message'),
+            runtime=runtime,
+            sid='test',
+            fake_user_response_fn=lambda _: 'repeat',
+            memory=mock_memory,
+            llm_registry=llm_registry,
+        )
     print(f'state: {state}')
     events = list(test_event_stream.get_events())
     print(f'event_stream: {events}')
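Two mechanical changes repeat through the rest of this file's hunks. AgentController(...) gains a convo_stats= argument, and run_controller no longer accepts the agent directly; the test hands it the llm_registry and patches openhands.core.main.create_agent to return the mock. A condensed sketch of the new call shape (the argument values are the test's own, not API requirements):

    with patch('openhands.core.main.create_agent', return_value=mock_agent):
        state = await run_controller(
            config=config,
            initial_user_action=MessageAction(content='Test message'),
            runtime=runtime,
            sid='test',
            fake_user_response_fn=lambda _: 'repeat',
            memory=mock_memory,
            llm_registry=llm_registry,  # replaces the old agent= argument
        )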
@@ -312,18 +366,16 @@

 @pytest.mark.asyncio
 async def test_run_controller_stop_with_stuck(
-    test_event_stream, mock_memory, mock_agent
+    test_event_stream, mock_memory, mock_agent_with_stats
 ):
     config = OpenHandsConfig()
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats

     def agent_step_fn(state):
         print(f'agent_step_fn received state: {state}')
         return CmdRunAction(command='ls')

     mock_agent.step = agent_step_fn
-    mock_agent.llm = MagicMock(spec=LLM)
-    mock_agent.llm.metrics = Metrics()
-    mock_agent.llm.config = config.get_llm_config()

     runtime = MagicMock(spec=ActionExecutionClient)

@@ -352,15 +404,17 @@
         EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
     )

-    state = await run_controller(
-        config=config,
-        initial_user_action=MessageAction(content='Test message'),
-        runtime=runtime,
-        sid='test',
-        agent=mock_agent,
-        fake_user_response_fn=lambda _: 'repeat',
-        memory=mock_memory,
-    )
+    # Mock the create_agent function to return our mock agent
+    with patch('openhands.core.main.create_agent', return_value=mock_agent):
+        state = await run_controller(
+            config=config,
+            initial_user_action=MessageAction(content='Test message'),
+            runtime=runtime,
+            sid='test',
+            fake_user_response_fn=lambda _: 'repeat',
+            memory=mock_memory,
+            llm_registry=llm_registry,
+        )
     events = list(test_event_stream.get_events())
     print(f'state: {state}')
     for i, event in enumerate(events):
@@ -391,11 +445,14 @@


 @pytest.mark.asyncio
-async def test_max_iterations_extension(mock_agent, mock_event_stream):
+async def test_max_iterations_extension(mock_agent_with_stats, mock_event_stream):
     # Test with headless_mode=False - should extend max_iterations
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -426,6 +483,7 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -450,7 +508,9 @@ async def test_max_iterations_extension(mock_agent, mock_event_stream):


 @pytest.mark.asyncio
-async def test_step_max_budget(mock_agent, mock_event_stream):
+async def test_step_max_budget(mock_agent_with_stats, mock_event_stream):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     # Metrics are always synced with budget flag before
     metrics = Metrics()
     metrics.accumulated_cost = 10.1
@@ -458,9 +518,13 @@ async def test_step_max_budget(mock_agent, mock_event_stream):
         limit_increase_amount=10, current_value=10.1, max_value=10
     )

+    # Update agent's LLM metrics in place
+    mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         budget_per_task_delta=10,
         sid='test',
@@ -475,7 +539,9 @@ async def test_step_max_budget(mock_agent, mock_event_stream):


 @pytest.mark.asyncio
-async def test_step_max_budget_headless(mock_agent, mock_event_stream):
+async def test_step_max_budget_headless(mock_agent_with_stats, mock_event_stream):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     # Metrics are always synced with budget flag before
     metrics = Metrics()
     metrics.accumulated_cost = 10.1
@@ -483,9 +549,13 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
         limit_increase_amount=10, current_value=10.1, max_value=10
     )

+    # Update agent's LLM metrics in place
+    mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         budget_per_task_delta=10,
         sid='test',
@@ -500,12 +570,14 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):


 @pytest.mark.asyncio
-async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
+async def test_budget_reset_on_continue(mock_agent_with_stats, mock_event_stream):
     """Test that when a user continues after hitting the budget limit:
     1. Error is thrown when budget cap is exceeded
     2. LLM budget does not reset when user continues
     3. Budget is extended by adding the initial budget cap to the current accumulated cost
     """
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     # Create a real Metrics instance shared between controller state and llm
     metrics = Metrics()
     metrics.accumulated_cost = 6.0
@@ -521,10 +593,14 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):
         ),
     )

+    # Update agent's LLM metrics in place
+    mock_agent.llm.metrics.accumulated_cost = metrics.accumulated_cost
+
     # Create controller with budget cap
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         budget_per_task_delta=initial_budget,
         sid='test',
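A detail worth noting in the budget hunks above: the tests now mutate the agent LLM's existing Metrics object in place (mock_agent.llm.metrics.accumulated_cost = ...) instead of assigning a fresh one. Since the registry and the conversation stats hold references to that same LLM, replacing the metrics object would presumably leave them tracking a stale instance; in-place mutation keeps all observers consistent.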
@@ -570,11 +646,17 @@ async def test_budget_reset_on_continue(mock_agent, mock_event_stream):


 @pytest.mark.asyncio
-async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
+async def test_reset_with_pending_action_no_observation(
+    mock_agent_with_stats, mock_event_stream
+):
     """Test reset() when there's a pending action with tool call metadata but no observation."""
+    # Connect LLM registry to conversation stats
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -617,11 +699,17 @@ async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_s


 @pytest.mark.asyncio
-async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_stream):
+async def test_reset_with_pending_action_stopped_state(
+    mock_agent_with_stats, mock_event_stream
+):
     """Test reset() when there's a pending action and agent state is STOPPED."""
+    # Connect LLM registry to conversation stats
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -665,12 +753,16 @@ async def test_reset_with_pending_action_stopped_state(mock_agent, mock_event_st

 @pytest.mark.asyncio
 async def test_reset_with_pending_action_existing_observation(
-    mock_agent, mock_event_stream
+    mock_agent_with_stats, mock_event_stream
 ):
     """Test reset() when there's a pending action with tool call metadata and an existing observation."""
+    # Connect LLM registry to conversation stats
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -708,11 +800,15 @@ async def test_reset_with_pending_action_existing_observation(


 @pytest.mark.asyncio
-async def test_reset_without_pending_action(mock_agent, mock_event_stream):
+async def test_reset_without_pending_action(mock_agent_with_stats, mock_event_stream):
     """Test reset() when there's no pending action."""
+    # Connect LLM registry to conversation stats
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -738,12 +834,15 @@ async def test_reset_without_pending_action(mock_agent, mock_event_stream):

 @pytest.mark.asyncio
 async def test_reset_with_pending_action_no_metadata(
-    mock_agent, mock_event_stream, monkeypatch
+    mock_agent_with_stats, mock_event_stream, monkeypatch
 ):
     """Test reset() when there's a pending action without tool call metadata."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -782,16 +881,13 @@ async def test_reset_with_pending_action_no_metadata(

 @pytest.mark.asyncio
 async def test_run_controller_max_iterations_has_metrics(
-    test_event_stream, mock_memory, mock_agent
+    test_event_stream, mock_memory, mock_agent_with_stats
 ):
     config = OpenHandsConfig(
         max_iterations=3,
     )
     event_stream = test_event_stream
-    mock_agent.llm = MagicMock(spec=LLM)
-    mock_agent.llm.metrics = Metrics()
-    mock_agent.llm.config = config.get_llm_config()
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats

     step_count = 0

@@ -833,15 +929,17 @@ async def test_run_controller_max_iterations_has_metrics(

     event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()))

-    state = await run_controller(
-        config=config,
-        initial_user_action=MessageAction(content='Test message'),
-        runtime=runtime,
-        sid='test',
-        agent=mock_agent,
-        fake_user_response_fn=lambda _: 'repeat',
-        memory=mock_memory,
-    )
+    # Mock the create_agent function to return our mock agent
+    with patch('openhands.core.main.create_agent', return_value=mock_agent):
+        state = await run_controller(
+            config=config,
+            initial_user_action=MessageAction(content='Test message'),
+            runtime=runtime,
+            sid='test',
+            fake_user_response_fn=lambda _: 'repeat',
+            memory=mock_memory,
+            llm_registry=llm_registry,
+        )

     state.metrics = mock_agent.llm.metrics
     assert state.iteration_flag.current_value == 3
@@ -867,10 +965,17 @@ async def test_run_controller_max_iterations_has_metrics(


 @pytest.mark.asyncio
-async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
+async def test_notify_on_llm_retry(
+    mock_agent_with_stats,
+    mock_event_stream,
+    mock_status_callback,
+):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     controller = AgentController(
         agent=mock_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         status_callback=mock_status_callback,
         iteration_delta=10,
         sid='test',
@@ -908,9 +1013,15 @@ async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_ca
     ],
 )
 async def test_context_window_exceeded_error_handling(
-    context_window_error, mock_agent, mock_runtime, test_event_stream, mock_memory
+    context_window_error,
+    mock_agent_with_stats,
+    mock_runtime,
+    test_event_stream,
+    mock_memory,
 ):
     """Test that context window exceeded errors are handled correctly by the controller, providing a smaller view but keeping the history intact."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     max_iterations = 5
     error_after = 2

@@ -973,18 +1084,20 @@ async def test_context_window_exceeded_error_handling(
     # state is set to error out before then, if this terminates and we have a
     # record of the error being thrown we can be confident that the controller
     # handles the truncation correctly.
-    final_state = await asyncio.wait_for(
-        run_controller(
-            config=config,
-            initial_user_action=MessageAction(content='INITIAL'),
-            runtime=mock_runtime,
-            sid='test',
-            agent=mock_agent,
-            fake_user_response_fn=lambda _: 'repeat',
-            memory=mock_memory,
-        ),
-        timeout=10,
-    )
+    # Mock the create_agent function to return our mock agent
+    with patch('openhands.core.main.create_agent', return_value=mock_agent):
+        final_state = await asyncio.wait_for(
+            run_controller(
+                config=config,
+                initial_user_action=MessageAction(content='INITIAL'),
+                runtime=mock_runtime,
+                sid='test',
+                fake_user_response_fn=lambda _: 'repeat',
+                memory=mock_memory,
+                llm_registry=llm_registry,
+            ),
+            timeout=10,
+        )

     # Check that the context window exception was thrown and the controller
     # called the agent's `step` function the right number of times.
@@ -1072,9 +1185,13 @@ async def test_context_window_exceeded_error_handling(

 @pytest.mark.asyncio
 async def test_run_controller_with_context_window_exceeded_with_truncation(
-    mock_agent, mock_runtime, mock_memory, test_event_stream
+    mock_agent_with_stats,
+    mock_runtime,
+    mock_memory,
+    test_event_stream,
 ):
     """Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats

     class StepState:
         def __init__(self):
@@ -1121,18 +1238,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
     mock_runtime.config = copy.deepcopy(config)

     try:
-        state = await asyncio.wait_for(
-            run_controller(
-                config=config,
-                initial_user_action=MessageAction(content='INITIAL'),
-                runtime=mock_runtime,
-                sid='test',
-                agent=mock_agent,
-                fake_user_response_fn=lambda _: 'repeat',
-                memory=mock_memory,
-            ),
-            timeout=10,
-        )
+        # Mock the create_agent function to return our mock agent
+        with patch('openhands.core.main.create_agent', return_value=mock_agent):
+            state = await asyncio.wait_for(
+                run_controller(
+                    config=config,
+                    initial_user_action=MessageAction(content='INITIAL'),
+                    runtime=mock_runtime,
+                    sid='test',
+                    fake_user_response_fn=lambda _: 'repeat',
+                    memory=mock_memory,
+                    llm_registry=llm_registry,
+                ),
+                timeout=10,
+            )

         # A timeout error indicates the run_controller entrypoint is not making
         # progress
@@ -1156,9 +1275,13 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(

 @pytest.mark.asyncio
 async def test_run_controller_with_context_window_exceeded_without_truncation(
-    mock_agent, mock_runtime, mock_memory, test_event_stream
+    mock_agent_with_stats,
+    mock_runtime,
+    mock_memory,
+    test_event_stream,
 ):
     """Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats

     class StepState:
         def __init__(self):
@@ -1199,18 +1322,20 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
     config = OpenHandsConfig(max_iterations=3)
     mock_runtime.config = copy.deepcopy(config)
     try:
-        state = await asyncio.wait_for(
-            run_controller(
-                config=config,
-                initial_user_action=MessageAction(content='INITIAL'),
-                runtime=mock_runtime,
-                sid='test',
-                agent=mock_agent,
-                fake_user_response_fn=lambda _: 'repeat',
-                memory=mock_memory,
-            ),
-            timeout=10,
-        )
+        # Mock the create_agent function to return our mock agent
+        with patch('openhands.core.main.create_agent', return_value=mock_agent):
+            state = await asyncio.wait_for(
+                run_controller(
+                    config=config,
+                    initial_user_action=MessageAction(content='INITIAL'),
+                    runtime=mock_runtime,
+                    sid='test',
+                    fake_user_response_fn=lambda _: 'repeat',
+                    memory=mock_memory,
+                    llm_registry=llm_registry,
+                ),
+                timeout=10,
+            )

         # A timeout error indicates the run_controller entrypoint is not making
         # progress
@@ -1244,7 +1369,11 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(


 @pytest.mark.asyncio
-async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
+async def test_run_controller_with_memory_error(
+    test_event_stream, mock_agent_with_stats
+):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     config = OpenHandsConfig()
     event_stream = test_event_stream

@@ -1273,15 +1402,17 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):
     with patch.object(
         memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
     ):
-        state = await run_controller(
-            config=config,
-            initial_user_action=MessageAction(content='Test message'),
-            runtime=runtime,
-            sid='test',
-            agent=mock_agent,
-            fake_user_response_fn=lambda _: 'repeat',
-            memory=memory,
-        )
+        # Mock the create_agent function to return our mock agent
+        with patch('openhands.core.main.create_agent', return_value=mock_agent):
+            state = await run_controller(
+                config=config,
+                initial_user_action=MessageAction(content='Test message'),
+                runtime=runtime,
+                sid='test',
+                fake_user_response_fn=lambda _: 'repeat',
+                memory=memory,
+                llm_registry=llm_registry,
+            )

     assert state.iteration_flag.current_value == 0
     assert state.agent_state == AgentState.ERROR
@@ -1289,7 +1420,9 @@ async def test_run_controller_with_memory_error(test_event_stream, mock_agent):


 @pytest.mark.asyncio
-async def test_action_metrics_copy(mock_agent):
+async def test_action_metrics_copy(mock_agent_with_stats):
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     # Setup
     file_store = InMemoryFileStore({})
     event_stream = EventStream(sid='test', file_store=file_store)
@@ -1299,8 +1432,7 @@ async def test_action_metrics_copy(mock_agent):

     initial_state = State(metrics=metrics, budget_flag=None)

-    # Create agent with metrics
-    mock_agent.llm = MagicMock(spec=LLM)
+    # Update agent's LLM metrics

     # Add multiple token usages - we should get the last one in the action
     usage1 = TokenUsage(
@@ -1342,6 +1474,11 @@ async def test_action_metrics_copy(mock_agent):

     mock_agent.llm.metrics = metrics

+    # Register the metrics with the LLM registry
+    llm_registry.service_to_llm['agent'] = mock_agent.llm
+    # Manually notify the conversation stats about the LLM registration
+    llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent'))
+
     # Mock agent step to return an action
     action = MessageAction(content='Test message')

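When a test wires an LLM into the registry by hand, as above and again for the condenser further down, it must also fire the registry's notification so the subscribed ConversationStats learns about it; RegistryEvent bundles the LLM with its service id. A sketch mirroring the hunk above:

    # Manual registration bypasses get_llm, so notify subscribers explicitly
    llm_registry.service_to_llm['agent'] = mock_agent.llm
    llm_registry.notify(RegistryEvent(llm=mock_agent.llm, service_id='agent'))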
@@ -1354,6 +1491,7 @@ async def test_action_metrics_copy(mock_agent):
     controller = AgentController(
         agent=mock_agent,
         event_stream=event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -1411,12 +1549,13 @@


 @pytest.mark.asyncio
-async def test_condenser_metrics_included(mock_agent, test_event_stream):
+async def test_condenser_metrics_included(mock_agent_with_stats, test_event_stream):
     """Test that metrics from the condenser's LLM are included in the action metrics."""
-    # Set up agent metrics
-    agent_metrics = Metrics(model_name='agent-model')
-    agent_metrics.accumulated_cost = 0.05
-    agent_metrics._accumulated_token_usage = TokenUsage(
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
+    # Set up agent metrics in place
+    mock_agent.llm.metrics.accumulated_cost = 0.05
+    mock_agent.llm.metrics._accumulated_token_usage = TokenUsage(
         model='agent-model',
         prompt_tokens=100,
         completion_tokens=50,
@@ -1424,7 +1563,6 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
         cache_write_tokens=10,
         response_id='agent-accumulated',
     )
-    # mock_agent.llm.metrics = agent_metrics
     mock_agent.name = 'TestAgent'

     # Create condenser with its own metrics
@@ -1442,6 +1580,11 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
     )
     condenser.llm.metrics = condenser_metrics

+    # Register the condenser metrics with the LLM registry
+    llm_registry.service_to_llm['condenser'] = condenser.llm
+    # Manually notify the conversation stats about the condenser LLM registration
+    llm_registry.notify(RegistryEvent(llm=condenser.llm, service_id='condenser'))
+
     # Attach the condenser to the mock_agent
     mock_agent.condenser = condenser

@@ -1463,11 +1606,12 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):
     controller = AgentController(
         agent=mock_agent,
         event_stream=test_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
         headless_mode=True,
-        initial_state=State(metrics=agent_metrics, budget_flag=None),
+        initial_state=State(metrics=mock_agent.llm.metrics, budget_flag=None),
     )

     # Execute one step
@@ -1505,7 +1649,9 @@ async def test_condenser_metrics_included(mock_agent, test_event_stream):


 @pytest.mark.asyncio
-async def test_first_user_message_with_identical_content(test_event_stream, mock_agent):
+async def test_first_user_message_with_identical_content(
+    test_event_stream, mock_agent_with_stats
+):
     """Test that _first_user_message correctly identifies the first user message.

     This test verifies that messages with identical content but different IDs are properly
@@ -1514,14 +1660,12 @@ async def test_first_user_message_with_identical_content(test_event_stream, mock
     The issue we're checking is that the comparison (action == self._first_user_message())
     should correctly differentiate between messages with the same content but different IDs.
     """
-    # Create an agent controller
-    mock_agent.llm = MagicMock(spec=LLM)
-    mock_agent.llm.metrics = Metrics()
-    mock_agent.llm.config = OpenHandsConfig().get_llm_config()
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats

     controller = AgentController(
         agent=mock_agent,
         event_stream=test_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test',
         confirmation_mode=False,
@@ -1569,11 +1713,15 @@


 @pytest.mark.asyncio
-async def test_agent_controller_processes_null_observation_with_cause():
+async def test_agent_controller_processes_null_observation_with_cause(
+    mock_agent_with_stats,
+):
     """Test that AgentController processes NullObservation events with a cause value.

     And that the agent's step method is called as a result.
     """
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     # Create an in-memory file store and real event stream
     file_store = InMemoryFileStore()
     event_stream = EventStream(sid='test-session', file_store=file_store)
@@ -1581,19 +1729,11 @@
     # Create a Memory instance - not used directly in this test but needed for setup
     Memory(event_stream=event_stream, sid='test-session')

-    # Create a mock agent with necessary attributes
-    mock_agent = MagicMock(spec=Agent)
-    mock_agent.get_system_message = MagicMock(
-        return_value=None,
-    )
-    mock_agent.llm = MagicMock(spec=LLM)
-    mock_agent.llm.metrics = Metrics()
-    mock_agent.llm.config = OpenHandsConfig().get_llm_config()
-
     # Create a controller with the mock agent
     controller = AgentController(
         agent=mock_agent,
         event_stream=event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test-session',
     )
@@ -1655,8 +1795,12 @@
     )


-def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agent):
+def test_agent_controller_should_step_with_null_observation_cause_zero(
+    mock_agent_with_stats,
+):
     """Test that AgentController's should_step method returns False for NullObservation with cause = 0."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     # Create a mock event stream
     file_store = InMemoryFileStore()
     event_stream = EventStream(sid='test-session', file_store=file_store)
@@ -1665,6 +1809,7 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
     controller = AgentController(
         agent=mock_agent,
         event_stream=event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=10,
         sid='test-session',
     )
@@ -1683,10 +1828,15 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
     )


-def test_system_message_in_event_stream(mock_agent, test_event_stream):
+def test_system_message_in_event_stream(mock_agent_with_stats, test_event_stream):
     """Test that SystemMessageAction is added to event stream in AgentController."""
+    mock_agent, conversation_stats, llm_registry = mock_agent_with_stats
+
     _ = AgentController(
-        agent=mock_agent, event_stream=test_event_stream, iteration_delta=10
+        agent=mock_agent,
+        event_stream=test_event_stream,
+        convo_stats=conversation_stats,
+        iteration_delta=10,
     )

     # Get events from the event stream
@@ -12,8 +12,9 @@ from openhands.controller.state.control_flags import (
|
|||||||
IterationControlFlag,
|
IterationControlFlag,
|
||||||
)
|
)
|
||||||
from openhands.controller.state.state import State
|
from openhands.controller.state.state import State
|
||||||
from openhands.core.config import LLMConfig
|
from openhands.core.config import OpenHandsConfig
|
||||||
from openhands.core.config.agent_config import AgentConfig
|
from openhands.core.config.agent_config import AgentConfig
|
||||||
|
from openhands.core.config.llm_config import LLMConfig
|
||||||
from openhands.core.schema import AgentState
|
from openhands.core.schema import AgentState
|
||||||
from openhands.events import EventSource, EventStream
|
from openhands.events import EventSource, EventStream
|
||||||
from openhands.events.action import (
|
from openhands.events.action import (
|
||||||
@@ -28,11 +29,39 @@ from openhands.events.event import Event, RecallType
|
|||||||
from openhands.events.observation.agent import RecallObservation
|
from openhands.events.observation.agent import RecallObservation
|
||||||
from openhands.events.stream import EventStreamSubscriber
|
from openhands.events.stream import EventStreamSubscriber
|
||||||
from openhands.llm.llm import LLM
|
from openhands.llm.llm import LLM
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.llm.metrics import Metrics
|
from openhands.llm.metrics import Metrics
|
||||||
from openhands.memory.memory import Memory
|
from openhands.memory.memory import Memory
|
||||||
|
from openhands.server.services.conversation_stats import ConversationStats
|
||||||
from openhands.storage.memory import InMemoryFileStore
|
from openhands.storage.memory import InMemoryFileStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm_registry():
|
||||||
|
config = OpenHandsConfig()
|
||||||
|
return LLMRegistry(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def conversation_stats():
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
file_store = InMemoryFileStore({})
|
||||||
|
# Use a unique conversation ID for each test to avoid conflicts
|
||||||
|
conversation_id = f'test-conversation-{uuid.uuid4()}'
|
||||||
|
return ConversationStats(
|
||||||
|
file_store=file_store, conversation_id=conversation_id, user_id='test-user'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def connected_registry_and_stats(llm_registry, conversation_stats):
|
||||||
|
"""Connect the LLMRegistry and ConversationStats properly"""
|
||||||
|
# Subscribe to LLM registry events to track metrics
|
||||||
|
llm_registry.subscribe(conversation_stats.register_llm)
|
||||||
|
return llm_registry, conversation_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_event_stream():
|
def mock_event_stream():
|
||||||
"""Creates an event stream in memory."""
|
"""Creates an event stream in memory."""
|
||||||
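Note: the connected_registry_and_stats fixture mirrors the production wiring
this refactor introduces: the registry notifies subscribers whenever a service
LLM is created, and ConversationStats records each LLM's Metrics so spend can
later be combined. A minimal sketch of that flow, assuming subscribe() delivers
a registration event for every LLM built via get_llm() (the end-to-end
assertion is an assumed invariant, not code from this commit):

    registry = LLMRegistry(config=OpenHandsConfig())
    stats = ConversationStats(
        file_store=InMemoryFileStore({}), conversation_id='demo', user_id='demo-user'
    )
    registry.subscribe(stats.register_llm)  # stats now sees every new LLM
    llm = registry.get_llm('demo-agent', LLMConfig(model='gpt-4o'))
    llm.metrics.add_cost(0.25)  # usage lands in the registered Metrics
    assert stats.get_combined_metrics().accumulated_cost == 0.25  # assumed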
@@ -42,15 +71,17 @@ def mock_event_stream():


 @pytest.fixture
-def mock_parent_agent():
+def mock_parent_agent(llm_registry):
     """Creates a mock parent agent for testing delegation."""
     agent = MagicMock(spec=Agent)
     agent.name = 'ParentAgent'
     agent.llm = MagicMock(spec=LLM)
+    agent.llm.service_id = 'main_agent'
     agent.llm.metrics = Metrics()
     agent.llm.config = LLMConfig()
     agent.llm.retry_listener = None  # Add retry_listener attribute
     agent.config = AgentConfig()
+    agent.llm_registry = llm_registry  # Add the missing llm_registry attribute

     # Add a proper system message mock
     system_message = SystemMessageAction(content='Test system message')
@@ -61,15 +92,17 @@ def mock_parent_agent():


 @pytest.fixture
-def mock_child_agent():
+def mock_child_agent(llm_registry):
     """Creates a mock child agent for testing delegation."""
     agent = MagicMock(spec=Agent)
     agent.name = 'ChildAgent'
     agent.llm = MagicMock(spec=LLM)
+    agent.llm.service_id = 'main_agent'
     agent.llm.metrics = Metrics()
     agent.llm.config = LLMConfig()
     agent.llm.retry_listener = None  # Add retry_listener attribute
     agent.config = AgentConfig()
+    agent.llm_registry = llm_registry  # Add the missing llm_registry attribute

     system_message = SystemMessageAction(content='Test system message')
     system_message._source = EventSource.AGENT
@@ -78,15 +111,37 @@ def mock_child_agent():
     return agent


+def create_mock_agent_factory(mock_child_agent, llm_registry):
+    """Helper function to create a mock agent factory with proper LLM registration."""
+
+    def create_mock_agent(config, llm_registry=None):
+        # Register the mock agent's LLM in the registry so get_combined_metrics() can find it
+        if llm_registry:
+            mock_child_agent.llm = llm_registry.get_llm('agent_llm', LLMConfig())
+            mock_child_agent.llm_registry = (
+                llm_registry  # Set the llm_registry attribute
+            )
+        return mock_child_agent
+
+    return create_mock_agent
+
+
 @pytest.mark.asyncio
-async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_stream):
-    """Test that when the parent agent delegates to a child
-    1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
-    2. metrics are accumulated globally (delegate is adding to the parents metrics)
-    3. local metrics for the delegate are still accessible
-    """
+async def test_delegation_flow(
+    mock_parent_agent, mock_child_agent, mock_event_stream, connected_registry_and_stats
+):
+    """
+    Test that when the parent agent delegates to a child
+    1. the parent's delegate is set, and once the child finishes, the parent is cleaned up properly.
+    2. metrics are accumulated globally via LLM registry (delegate adds to the global metrics)
+    3. global metrics tracking works correctly through the LLM registry
+    """
+    llm_registry, conversation_stats = connected_registry_and_stats

     # Mock the agent class resolution so that AgentController can instantiate mock_child_agent
-    Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
+    Agent.get_cls = Mock(
+        return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
+    )

     step_count = 0

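Note: the lambda-to-factory swap tracks the new Agent constructor order. Before
the refactor the controller effectively did Agent.get_cls(name)(llm, config);
after it, the call is Agent.get_cls(name)(config, llm_registry), so any test
double must accept (config, llm_registry). A sketch of a stand-in that matches
the new calling convention (StubAgent is hypothetical, not from this commit):

    class StubAgent:
        def __init__(self, config: AgentConfig, llm_registry: LLMRegistry):
            self.config = config
            # obtain the LLM through the registry, as real agents now do
            self.llm = llm_registry.get_llm('agent_llm', LLMConfig())

    Agent.get_cls = Mock(return_value=StubAgent)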
@@ -97,6 +152,12 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s

     mock_child_agent.step = agent_step_fn

+    # Set up the parent agent's LLM with initial cost and register it in the registry
+    # The parent agent's LLM should use the existing registered LLM to ensure proper tracking
+    parent_llm = llm_registry.service_to_llm['agent']
+    parent_llm.metrics.accumulated_cost = 2
+    mock_parent_agent.llm = parent_llm
+
     parent_metrics = Metrics()
     parent_metrics.accumulated_cost = 2
     # Create parent controller
@@ -114,6 +175,7 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
     parent_controller = AgentController(
         agent=mock_parent_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=1,  # Add the required iteration_delta parameter
         sid='parent',
         confirmation_mode=False,
@@ -180,21 +242,23 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
     for i in range(4):
         delegate_controller.state.iteration_flag.step()
         delegate_controller.agent.step(delegate_controller.state)
+        # Update the agent's LLM metrics (not the deprecated state metrics)
        delegate_controller.agent.llm.metrics.add_cost(1.0)

     assert (
         delegate_controller.state.get_local_step() == 4
     )  # verify local metrics are accessible via snapshot

+    # Check that the conversation stats has the combined metrics (parent + delegate)
+    combined_metrics = delegate_controller.state.convo_stats.get_combined_metrics()
     assert (
-        delegate_controller.state.metrics.accumulated_cost
-        == 6  # Make sure delegate tracks global cost
+        combined_metrics.accumulated_cost
+        == 6  # Make sure delegate tracks global cost (2 from parent + 4 from delegate)
     )

-    assert (
-        delegate_controller.state.get_local_metrics().accumulated_cost
-        == 4  # Delegate spent one dollar per step
-    )
+    # Since metrics are now global via LLM registry, local metrics tracking
+    # is handled differently. The delegate's LLM shares the same metrics object
+    # as the parent for global tracking, so we verify the global total is correct.

     delegate_controller.state.outputs = {'delegate_result': 'done'}

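Note: the expected total of 6 is the sum the registry-backed stats should see:
the parent LLM is seeded with accumulated_cost = 2 and the delegate loop calls
add_cost(1.0) four times, so 2 + 4 * 1.0 = 6. A hedged sketch of the
bookkeeping, assuming get_combined_metrics() folds each registered service's
Metrics together (the reduction below is an illustration, not the real body):

    # hypothetical reduction equivalent to get_combined_metrics()
    total = sum(
        llm.metrics.accumulated_cost for llm in llm_registry.service_to_llm.values()
    )
    assert total == 2 + 4 * 1.0  # parent seed + four delegate steps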
@@ -228,15 +292,18 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
     ],
 )
 async def test_delegate_step_different_states(
-    mock_parent_agent, mock_event_stream, delegate_state
+    mock_parent_agent, mock_event_stream, delegate_state, connected_registry_and_stats
 ):
     """Ensure that delegate is closed or remains open based on the delegate's state."""
+    llm_registry, conversation_stats = connected_registry_and_stats
+
     # Create a state with iteration_flag.max_value set to 10
     state = State(inputs={})
     state.iteration_flag.max_value = 10
     controller = AgentController(
         agent=mock_parent_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=1,  # Add the required iteration_delta parameter
         sid='test',
         confirmation_mode=False,
@@ -292,11 +359,23 @@ async def test_delegate_step_different_states(

 @pytest.mark.asyncio
 async def test_delegate_hits_global_limits(
-    mock_child_agent, mock_event_stream, mock_parent_agent
+    mock_child_agent, mock_event_stream, mock_parent_agent, connected_registry_and_stats
 ):
-    """Global limits from control flags should apply to delegates"""
+    """
+    Global limits from control flags should apply to delegates
+    """
+    llm_registry, conversation_stats = connected_registry_and_stats
+
     # Mock the agent class resolution so that AgentController can instantiate mock_child_agent
-    Agent.get_cls = Mock(return_value=lambda llm, config: mock_child_agent)
+    Agent.get_cls = Mock(
+        return_value=create_mock_agent_factory(mock_child_agent, llm_registry)
+    )
+
+    # Set up the parent agent's LLM with initial cost and register it in the registry
+    mock_parent_agent.llm.metrics.accumulated_cost = 2
+    mock_parent_agent.llm.service_id = 'main_agent'
+    # Register the parent agent's LLM in the registry
+    llm_registry.service_to_llm['main_agent'] = mock_parent_agent.llm
+
     parent_metrics = Metrics()
     parent_metrics.accumulated_cost = 2
@@ -315,6 +394,7 @@ async def test_delegate_hits_global_limits(
     parent_controller = AgentController(
         agent=mock_parent_agent,
         event_stream=mock_event_stream,
+        convo_stats=conversation_stats,
         iteration_delta=1,  # Add the required iteration_delta parameter
         sid='parent',
         confirmation_mode=False,
@@ -9,12 +9,13 @@ from openhands.core.config import LLMConfig, OpenHandsConfig
 from openhands.core.config.agent_config import AgentConfig
 from openhands.events import EventStream, EventStreamSubscriber
 from openhands.integrations.service_types import ProviderType
-from openhands.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.llm.metrics import Metrics
 from openhands.memory.memory import Memory
 from openhands.runtime.impl.action_execution.action_execution_client import (
     ActionExecutionClient,
 )
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.server.session.agent_session import AgentSession
 from openhands.storage.memory import InMemoryFileStore

@@ -22,44 +23,70 @@ from openhands.storage.memory import InMemoryFileStore


 @pytest.fixture
-def mock_agent():
-    """Create a properly configured mock agent with all required nested attributes"""
-    # Create the base mocks
-    agent = MagicMock(spec=Agent)
-    llm = MagicMock(spec=LLM)
-    metrics = MagicMock(spec=Metrics)
-    llm_config = MagicMock(spec=LLMConfig)
-    agent_config = MagicMock(spec=AgentConfig)
-
-    # Configure the LLM config
-    llm_config.model = 'test-model'
-    llm_config.base_url = 'http://test'
-    llm_config.max_message_chars = 1000
-
-    # Configure the agent config
-    agent_config.disabled_microagents = []
-    agent_config.enable_mcp = True
-
-    # Set up the chain of mocks
-    llm.metrics = metrics
-    llm.config = llm_config
-    agent.llm = llm
-    agent.name = 'test-agent'
-    agent.sandbox_plugins = []
-    agent.config = agent_config
-    agent.prompt_manager = MagicMock()
-
-    return agent
+def mock_llm_registry():
+    """Create a mock LLM registry that properly simulates LLM registration"""
+    config = OpenHandsConfig()
+    registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
+    return registry
+
+
+@pytest.fixture
+def mock_conversation_stats():
+    """Create a mock ConversationStats that properly simulates metrics tracking"""
+    file_store = InMemoryFileStore({})
+    stats = ConversationStats(
+        file_store=file_store, conversation_id='test-conversation', user_id='test-user'
+    )
+    return stats
+
+
+@pytest.fixture
+def connected_registry_and_stats(mock_llm_registry, mock_conversation_stats):
+    """Connect the LLMRegistry and ConversationStats properly"""
+    # Subscribe to LLM registry events to track metrics
+    mock_llm_registry.subscribe(mock_conversation_stats.register_llm)
+    return mock_llm_registry, mock_conversation_stats
+
+
+@pytest.fixture
+def make_mock_agent():
+    def _make_mock_agent(llm_registry):
+        agent = MagicMock(spec=Agent)
+        agent_config = MagicMock(spec=AgentConfig)
+        llm_config = LLMConfig(
+            model='gpt-4o',
+            api_key='test_key',
+            num_retries=2,
+            retry_min_wait=1,
+            retry_max_wait=2,
+        )
+        agent_config.disabled_microagents = []
+        agent_config.enable_mcp = True
+        llm_registry.service_to_llm.clear()
+        mock_llm = llm_registry.get_llm('agent_llm', llm_config)
+        agent.llm = mock_llm
+        agent.name = 'test-agent'
+        agent.sandbox_plugins = []
+        agent.config = agent_config
+        agent.prompt_manager = MagicMock()
+        return agent
+
+    return _make_mock_agent


 @pytest.mark.asyncio
-async def test_agent_session_start_with_no_state(mock_agent):
+async def test_agent_session_start_with_no_state(
+    make_mock_agent, mock_llm_registry, mock_conversation_stats
+):
     """Test that AgentSession.start() works correctly when there's no state to restore"""
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )

     # Create a mock runtime and set it up
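Note: make_mock_agent is a factory fixture rather than a plain fixture so each
test can bind the mock agent to whichever registry instance it actually uses.
The intended call pattern, with the second line being an assumed invariant
(get_llm() should both return the LLM and record it under that service id):

    mock_agent = make_mock_agent(mock_llm_registry)  # LLM comes from get_llm()
    assert mock_agent.llm is mock_llm_registry.service_to_llm['agent_llm']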
@@ -140,13 +167,18 @@ async def test_agent_session_start_with_no_state(mock_agent):


 @pytest.mark.asyncio
-async def test_agent_session_start_with_restored_state(mock_agent):
+async def test_agent_session_start_with_restored_state(
+    make_mock_agent, mock_llm_registry, mock_conversation_stats
+):
     """Test that AgentSession.start() works correctly when there's a state to restore"""
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )

     # Create a mock runtime and set it up
@@ -230,13 +262,21 @@ async def test_agent_session_start_with_restored_state(mock_agent):


 @pytest.mark.asyncio
-async def test_metrics_centralization_and_sharing(mock_agent):
-    """Test that metrics are centralized and shared between controller and agent."""
+async def test_metrics_centralization_via_conversation_stats(
+    make_mock_agent, connected_registry_and_stats
+):
+    """Test that metrics are centralized through the ConversationStats service."""
+
+    mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
+    mock_agent = make_mock_agent(mock_llm_registry)
+
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )

     # Create a mock runtime and set it up
@@ -262,6 +302,8 @@ async def test_metrics_centralization_and_sharing(mock_agent):
     memory = Memory(event_stream=mock_event_stream, sid='test-session')
     memory.microagents_dir = 'test-dir'

+    # The registry already has a real metrics object set up in the fixture
+
     # Patch necessary components
     with (
         patch(
@@ -281,49 +323,50 @@ async def test_metrics_centralization_and_sharing(mock_agent):
             max_iterations=10,
         )

-        # Verify that the agent's LLM metrics and controller's state metrics are the same object
-        assert session.controller.agent.llm.metrics is session.controller.state.metrics
+        # Verify that the ConversationStats is properly set up
+        assert session.controller.state.convo_stats is mock_conversation_stats

-        # Add some metrics to the agent's LLM
+        # Add some metrics to the agent's LLM (simulating LLM usage)
         test_cost = 0.05
         session.controller.agent.llm.metrics.add_cost(test_cost)

-        # Verify that the cost is reflected in the controller's state metrics
-        assert session.controller.state.metrics.accumulated_cost == test_cost
+        # Verify that the cost is reflected in the combined metrics from the conversation stats
+        combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
+        assert combined_metrics.accumulated_cost == test_cost

-        # Create a test metrics object to simulate an observation with metrics
-        test_observation_metrics = Metrics()
-        test_observation_metrics.add_cost(0.1)
+        # Add more cost to simulate additional LLM usage
+        additional_cost = 0.1
+        session.controller.agent.llm.metrics.add_cost(additional_cost)

-        # Get the current accumulated cost before merging
-        current_cost = session.controller.state.metrics.accumulated_cost
+        # Verify the combined metrics reflect the total cost
+        combined_metrics = session.controller.state.convo_stats.get_combined_metrics()
+        assert combined_metrics.accumulated_cost == test_cost + additional_cost

-        # Simulate merging metrics from an observation
-        session.controller.state_tracker.merge_metrics(test_observation_metrics)
-
-        # Verify that the merged metrics are reflected in both agent and controller
-        assert session.controller.state.metrics.accumulated_cost == current_cost + 0.1
-        assert (
-            session.controller.agent.llm.metrics.accumulated_cost == current_cost + 0.1
-        )
-
-        # Reset the agent and verify that metrics are not reset
+        # Reset the agent and verify that combined metrics are preserved
         session.controller.agent.reset()

-        # Metrics should still be the same after reset
-        assert session.controller.state.metrics.accumulated_cost == test_cost + 0.1
-        assert session.controller.agent.llm.metrics.accumulated_cost == test_cost + 0.1
-        assert session.controller.agent.llm.metrics is session.controller.state.metrics
+        # Combined metrics should still be preserved after agent reset
+        assert (
+            session.controller.state.convo_stats.get_combined_metrics().accumulated_cost
+            == test_cost + additional_cost
+        )


 @pytest.mark.asyncio
-async def test_budget_control_flag_syncs_with_metrics(mock_agent):
+async def test_budget_control_flag_syncs_with_metrics(
+    make_mock_agent, connected_registry_and_stats
+):
     """Test that BudgetControlFlag's current value matches the accumulated costs."""
+
+    mock_llm_registry, mock_conversation_stats = connected_registry_and_stats
+    mock_agent = make_mock_agent(mock_llm_registry)
     # Setup
     file_store = InMemoryFileStore({})
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )

     # Create a mock runtime and set it up
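Note: the rewritten assertions change what is compared, not just names. The old
test pinned an object-identity invariant (agent.llm.metrics is state.metrics);
with the registry, costs flow through an aggregation instead, so the invariant
becomes value equality on a combined snapshot. A sketch of the new contract,
where the "fresh Metrics built by folding" shape is an assumption:

    session.controller.agent.llm.metrics.add_cost(0.05)
    snapshot = session.controller.state.convo_stats.get_combined_metrics()
    # snapshot is (assumed to be) a fresh Metrics folded from every registered
    # LLM's metrics, so identity checks against it are no longer meaningful
    assert snapshot.accumulated_cost == 0.05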
@@ -349,6 +392,8 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
     memory = Memory(event_stream=mock_event_stream, sid='test-session')
     memory.microagents_dir = 'test-dir'

+    # The registry already has a real metrics object set up in the fixture
+
     # Patch necessary components
     with (
         patch(
@@ -375,7 +420,7 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
         assert session.controller.state.budget_flag.max_value == 1.0
         assert session.controller.state.budget_flag.current_value == 0.0

-        # Add some metrics to the agent's LLM
+        # Add some metrics to the agent's LLM (simulating LLM usage)
         test_cost = 0.05
         session.controller.agent.llm.metrics.add_cost(test_cost)

@@ -384,24 +429,31 @@ async def test_budget_control_flag_syncs_with_metrics(mock_agent):
         session.controller.state_tracker.sync_budget_flag_with_metrics()
         assert session.controller.state.budget_flag.current_value == test_cost

-        # Create a test metrics object to simulate an observation with metrics
-        test_observation_metrics = Metrics()
-        test_observation_metrics.add_cost(0.1)
+        # Add more cost to simulate additional LLM usage
+        additional_cost = 0.1
+        session.controller.agent.llm.metrics.add_cost(additional_cost)

-        # Simulate merging metrics from an observation
-        session.controller.state_tracker.merge_metrics(test_observation_metrics)
+        # Sync again and verify the budget flag is updated
+        session.controller.state_tracker.sync_budget_flag_with_metrics()
+        assert (
+            session.controller.state.budget_flag.current_value
+            == test_cost + additional_cost
+        )

-        # Verify that the budget control flag's current value is updated to match the new accumulated cost
-        assert session.controller.state.budget_flag.current_value == test_cost + 0.1
-
-        # Reset the agent and verify that metrics and budget flag are not reset
+        # Reset the agent and verify that budget flag still reflects the accumulated cost
         session.controller.agent.reset()

         # Budget control flag should still reflect the accumulated cost after reset
-        assert session.controller.state.budget_flag.current_value == test_cost + 0.1
+        session.controller.state_tracker.sync_budget_flag_with_metrics()
+        assert (
+            session.controller.state.budget_flag.current_value
+            == test_cost + additional_cost
+        )


-def test_override_provider_tokens_with_custom_secret():
+def test_override_provider_tokens_with_custom_secret(
+    mock_llm_registry, mock_conversation_stats
+):
     """Test that override_provider_tokens_with_custom_secret works correctly.

     This test verifies that the method properly removes provider tokens when
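Note: sync_budget_flag_with_metrics() is now the only bridge between spend and
the budget limit; nothing updates budget_flag implicitly, which is why the test
calls the sync explicitly after every add_cost() before asserting. A hedged
sketch of what the sync is expected to do (the one-line body and the
convo_stats attribute on the tracker are assumptions, not code from this diff):

    # assumed behavior: copy the aggregated spend into the control flag
    def sync_budget_flag_with_metrics(self):
        self.state.budget_flag.current_value = (
            self.convo_stats.get_combined_metrics().accumulated_cost
        )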
@@ -413,6 +465,8 @@ def test_override_provider_tokens_with_custom_secret():
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=mock_conversation_stats,
     )

     # Create test data
@@ -30,6 +30,7 @@ from openhands.agenthub.readonly_agent.tools import (
 )
 from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig, LLMConfig
+from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.core.exceptions import FunctionCallNotExistsError
 from openhands.core.message import ImageContent, Message, TextContent
 from openhands.events.action import (
@@ -42,10 +43,20 @@ from openhands.events.observation.commands import (
     CmdOutputObservation,
 )
 from openhands.events.tool import ToolCallMetadata
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.memory.condenser import View


+@pytest.fixture
+def create_llm_registry():
+    def _get_registry(llm_config):
+        config = OpenHandsConfig()
+        config.set_llm_config(llm_config)
+        return LLMRegistry(config=config)
+
+    return _get_registry
+
+
 @pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
 def agent_class(request):
     if request.param == 'CodeActAgent':
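Note: create_llm_registry is another factory fixture: it seeds a fresh
OpenHandsConfig with the given LLMConfig before building the registry, so the
registry's default service resolves to that model. Typical use, exactly as the
hunks below apply it:

    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
    agent = CodeActAgent(
        config=AgentConfig(), llm_registry=create_llm_registry(llm_config)
    )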
@@ -57,18 +68,22 @@ def agent_class(request):


 @pytest.fixture
-def agent(agent_class) -> Union[CodeActAgent, ReadOnlyAgent]:
+def agent(agent_class, create_llm_registry) -> Union[CodeActAgent, ReadOnlyAgent]:
+    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
     config = AgentConfig()
-    agent = agent_class(llm=LLM(LLMConfig()), config=config)
+    agent = agent_class(config=config, llm_registry=create_llm_registry(llm_config))
     agent.llm = Mock()
     agent.llm.config = Mock()
     agent.llm.config.max_message_chars = 1000
     return agent


-def test_agent_with_default_config_has_default_tools():
+def test_agent_with_default_config_has_default_tools(create_llm_registry):
+    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
     config = AgentConfig()
-    codeact_agent = CodeActAgent(llm=LLM(LLMConfig()), config=config)
+    codeact_agent = CodeActAgent(
+        config=config, llm_registry=create_llm_registry(llm_config)
+    )
     assert len(codeact_agent.tools) > 0
     default_tool_names = [tool['function']['name'] for tool in codeact_agent.tools]
     assert {
@@ -231,7 +246,7 @@ def test_response_to_actions_invalid_tool():
         readonly_response_to_actions(mock_response)


-def test_step_with_no_pending_actions(mock_state: State):
+def test_step_with_no_pending_actions(mock_state: State, create_llm_registry):
     # Mock the LLM response
     mock_response = Mock()
     mock_response.id = 'mock_id'
@@ -252,9 +267,12 @@ def test_step_with_no_pending_actions(mock_state: State):
     llm.format_messages_for_llm = Mock(return_value=[])  # Mock message formatting

     # Create agent with mocked LLM
+    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
     config = AgentConfig()
     config.enable_prompt_extensions = False
-    agent = CodeActAgent(llm=llm, config=config)
+    agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))
+    # Replace the LLM with our mock after creation
+    agent.llm = llm

     # Test step with no pending actions
     mock_state.latest_user_message = None
@@ -281,15 +299,10 @@ def test_step_with_no_pending_actions(mock_state: State):

 @pytest.mark.parametrize('agent_type', ['CodeActAgent', 'ReadOnlyAgent'])
 def test_correct_tool_description_loaded_based_on_model_name(
-    agent_type, mock_state: State
+    agent_type, create_llm_registry
 ):
     """Tests that the simplified tool descriptions are loaded for specific models."""
-    o3_mock_config = Mock()
-    o3_mock_config.model = 'mock_o3_model'
-
-    llm = Mock()
-    llm.config = o3_mock_config
+    o3_mock_config = LLMConfig(model='mock_o3_model', api_key='test_key')

     if agent_type == 'CodeActAgent':
         from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent

@@ -299,16 +312,19 @@ def test_correct_tool_description_loaded_based_on_model_name(

         agent_class = ReadOnlyAgent

-    agent = agent_class(llm=llm, config=AgentConfig())
+    agent = agent_class(
+        config=AgentConfig(),
+        llm_registry=create_llm_registry(o3_mock_config),
+    )
     for tool in agent.tools:
         # Assert all descriptions have less than 1024 characters
         assert len(tool['function']['description']) < 1024

-    sonnet_mock_config = Mock()
-    sonnet_mock_config.model = 'mock_sonnet_model'
-    llm.config = sonnet_mock_config
-    agent = agent_class(llm=llm, config=AgentConfig())
+    sonnect_mock_config = LLMConfig(model='mock_sonnet_model', api_key='test_key')
+    agent = agent_class(
+        config=AgentConfig(),
+        llm_registry=create_llm_registry(sonnect_mock_config),
+    )
     # Assert existence of the detailed tool descriptions that are longer than 1024 characters
     if agent_type == 'CodeActAgent':
         # This only holds for CodeActAgent
@@ -481,10 +497,12 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages(
     assert isinstance(enhanced_messages[5].content[0], ImageContent)


-def test_get_system_message():
+def test_get_system_message(create_llm_registry):
     """Test that the Agent.get_system_message method returns a SystemMessageAction."""
     # Create a mock agent
-    agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
+    config = AgentConfig()
+    llm_config = LLMConfig(model='gpt-4o', api_key='test_key')
+    agent = CodeActAgent(config=config, llm_registry=create_llm_registry(llm_config))

     result = agent.get_system_message()

@@ -34,7 +34,7 @@ def test_completion_retries_api_connection_error(
     ]

     # Create an LLM instance and call completion
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')
     response = llm.completion(
         messages=[{'role': 'user', 'content': 'Hello!'}],
         stream=False,
@@ -70,7 +70,7 @@ def test_completion_max_retries_api_connection_error(
     ]

     # Create an LLM instance and call completion
-    llm = LLM(config=default_config)
+    llm = LLM(config=default_config, service_id='test-service')

     # The completion should raise an APIConnectionError after exhausting all retries
     with pytest.raises(APIConnectionError) as excinfo:
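Note: LLM now carries a service_id so the registry (and anything subscribed to
it, such as ConversationStats) can attribute usage to a logical service. When
an LLM is constructed directly in a test it must be passed by hand; via the
registry it is implied by the registration key. A sketch of the distinction
(the "id implied" behavior of get_llm() is an assumption from this diff's
usage, not a documented contract):

    direct = LLM(config=default_config, service_id='test-service')  # manual id
    via_registry = llm_registry.get_llm('test-service', default_config)  # id implied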
@@ -5,11 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

-from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.events.action import MessageAction
 from openhands.events.event import EventSource
 from openhands.events.event_store import EventStore
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.server.conversation_manager.standalone_conversation_manager import (
     StandaloneConversationManager,
 )
@@ -24,6 +24,7 @@ async def test_auto_generate_title_with_llm():
     """Test auto-generating a title using LLM."""
     # Mock dependencies
     file_store = InMemoryFileStore()
+    llm_registry = MagicMock(spec=LLMRegistry)

     # Create test conversation with a user message
     conversation_id = 'test-conversation'
@@ -46,43 +47,33 @@ async def test_auto_generate_title_with_llm():
         mock_event_store.search_events.return_value = [user_message]
         mock_event_store_cls.return_value = mock_event_store

-        # Mock the LLM response
-        with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
-            mock_llm = mock_llm_cls.return_value
-            mock_response = MagicMock()
-            mock_response.choices = [MagicMock()]
-            mock_response.choices[0].message.content = 'Python Data Analysis Script'
-            mock_llm.completion.return_value = mock_response
+        # Mock the LLM registry response
+        llm_registry.request_extraneous_completion.return_value = (
+            'Python Data Analysis Script'
+        )

         # Create test settings with LLM config
         settings = Settings(
             llm_model='test-model',
             llm_api_key='test-key',
             llm_base_url='test-url',
         )

         # Call the auto_generate_title function directly
         title = await auto_generate_title(
-            conversation_id, user_id, file_store, settings
+            conversation_id, user_id, file_store, settings, llm_registry
         )

         # Verify the result
         assert title == 'Python Data Analysis Script'

         # Verify EventStore was created with the correct parameters
         mock_event_store_cls.assert_called_once_with(
             conversation_id, file_store, user_id
         )

-        # Verify LLM was called with appropriate parameters
-        mock_llm_cls.assert_called_once_with(
-            LLMConfig(
-                model='test-model',
-                api_key='test-key',
-                base_url='test-url',
-            )
-        )
-        mock_llm.completion.assert_called_once()
+        # Verify LLM registry was called with appropriate parameters
+        llm_registry.request_extraneous_completion.assert_called_once()


 @pytest.mark.asyncio
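Note: request_extraneous_completion is the registry's path for one-off,
non-agent completions (here, title generation), so tests stub one method on
the registry instead of patching the LLM class wholesale. The stubbing pattern
used below, assuming the method returns the completion text directly:

    llm_registry = MagicMock(spec=LLMRegistry)
    # success path: the registry hands back the title string
    llm_registry.request_extraneous_completion.return_value = 'Some Title'
    # failure path: make the registry raise so the fallback truncation runs
    llm_registry.request_extraneous_completion.side_effect = Exception('Test error')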
@@ -90,6 +81,7 @@ async def test_auto_generate_title_fallback():
     """Test auto-generating a title with fallback to truncation when LLM fails."""
     # Mock dependencies
     file_store = InMemoryFileStore()
+    llm_registry = MagicMock(spec=LLMRegistry)

     # Create test conversation with a user message
     conversation_id = 'test-conversation'
@@ -111,31 +103,29 @@ async def test_auto_generate_title_fallback():
         mock_event_store.search_events.return_value = [user_message]
         mock_event_store_cls.return_value = mock_event_store

-        # Mock the LLM to raise an exception
-        with patch('openhands.utils.conversation_summary.LLM') as mock_llm_cls:
-            mock_llm = mock_llm_cls.return_value
-            mock_llm.completion.side_effect = Exception('Test error')
+        # Mock the LLM registry to raise an exception
+        llm_registry.request_extraneous_completion.side_effect = Exception('Test error')

         # Create test settings with LLM config
         settings = Settings(
             llm_model='test-model',
             llm_api_key='test-key',
             llm_base_url='test-url',
         )

         # Call the auto_generate_title function directly
         title = await auto_generate_title(
-            conversation_id, user_id, file_store, settings
+            conversation_id, user_id, file_store, settings, llm_registry
         )

         # Verify the result is a truncated version of the message
         assert title == 'This is a very long message th...'
         assert len(title) <= 35

         # Verify EventStore was created with the correct parameters
         mock_event_store_cls.assert_called_once_with(
             conversation_id, file_store, user_id
         )


 @pytest.mark.asyncio
@@ -143,6 +133,7 @@ async def test_auto_generate_title_no_messages():
     """Test auto-generating a title when there are no user messages."""
     # Mock dependencies
     file_store = InMemoryFileStore()
+    llm_registry = MagicMock(spec=LLMRegistry)

     # Create test conversation with no messages
     conversation_id = 'test-conversation'
@@ -166,7 +157,7 @@ async def test_auto_generate_title_no_messages():

     # Call the auto_generate_title function directly
     title = await auto_generate_title(
-        conversation_id, user_id, file_store, settings
+        conversation_id, user_id, file_store, settings, llm_registry
     )

     # Verify the result is empty
@@ -186,6 +177,7 @@ async def test_update_conversation_with_title():
     sio.emit = AsyncMock()
     file_store = InMemoryFileStore()
     server_config = MagicMock()
+    llm_registry = MagicMock(spec=LLMRegistry)

     # Create test conversation
     conversation_id = 'test-conversation'
@@ -222,7 +214,9 @@ async def test_update_conversation_with_title():
         AsyncMock(return_value='Generated Title'),
     ):
         # Call the method
-        await manager._update_conversation_for_event(user_id, conversation_id, settings)
+        await manager._update_conversation_for_event(
+            user_id, conversation_id, settings, llm_registry
+        )

         # Verify the title was updated
         assert mock_metadata.title == 'Generated Title'
@@ -6,6 +6,7 @@ import pytest_asyncio

 from openhands.cli import main as cli
 from openhands.controller.state.state import State
+from openhands.core.config.llm_config import LLMConfig
 from openhands.events import EventSource
 from openhands.events.action import MessageAction

@@ -124,12 +125,14 @@ def mock_config():
         ''  # Empty string, not starting with 'tvly-'
     )
     config.search_api_key = search_api_key_mock
+    config.get_llm_config_from_agent.return_value = LLMConfig(model='model')

     # Mock sandbox with volumes attribute to prevent finalize_config issues
     config.sandbox = MagicMock()
     config.sandbox.volumes = (
         None  # This prevents finalize_config from overriding workspace_base
     )
+    config.model_name = 'model'

     return config

@@ -213,7 +216,11 @@ async def test_run_session_without_initial_action(
     # Assertions for initialization flow
     mock_display_runtime_init.assert_called_once_with('local')
     mock_display_animation.assert_called_once()
-    mock_create_agent.assert_called_once_with(mock_config)
+    # Check that mock_config is the first parameter to create_agent
+    mock_create_agent.assert_called_once()
+    assert mock_create_agent.call_args[0][0] == mock_config, (
+        'First parameter to create_agent should be mock_config'
+    )
     mock_add_mcp_tools.assert_called_once_with(mock_agent, mock_runtime, mock_memory)
     mock_create_runtime.assert_called_once()
     mock_create_controller.assert_called_once()
@@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 import pytest_asyncio
 from litellm.exceptions import AuthenticationError
+from pydantic import SecretStr

 from openhands.cli import main as cli
+from openhands.core.config.llm_config import LLMConfig
 from openhands.events import EventSource
 from openhands.events.action import MessageAction

@@ -45,11 +47,10 @@ def mock_config():
     config.workspace_base = '/test/dir'

     # Set up LLM config to use OpenHands provider
-    llm_config = MagicMock()
+    llm_config = LLMConfig(model='openhands/o3', api_key=SecretStr('invalid-api-key'))
     llm_config.model = 'openhands/o3'  # Use OpenHands provider with o3 model
-    llm_config.api_key = MagicMock()
-    llm_config.api_key.get_secret_value.return_value = 'invalid-api-key'
-    config.llm = llm_config
+    config.get_llm_config.return_value = llm_config
+    config.get_llm_config_from_agent.return_value = llm_config

     # Mock search_api_key with get_secret_value method
     search_api_key_mock = MagicMock()
@@ -13,6 +13,7 @@ from openhands.core.config.mcp_config import (
 from openhands.events.action.mcp import MCPAction
 from openhands.events.observation import ErrorObservation
 from openhands.events.observation.mcp import MCPObservation
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.impl.cli.cli_runtime import CLIRuntime


@@ -23,8 +24,12 @@ class TestCLIRuntimeMCP:
         """Set up test fixtures."""
         self.config = OpenHandsConfig()
         self.event_stream = MagicMock()
+        llm_registry = LLMRegistry(config=OpenHandsConfig())
         self.runtime = CLIRuntime(
-            config=self.config, event_stream=self.event_stream, sid='test-session'
+            config=self.config,
+            event_stream=self.event_stream,
+            sid='test-session',
+            llm_registry=llm_registry,
         )

     @pytest.mark.asyncio
@@ -7,10 +7,18 @@ import pytest

 from openhands.core.config import OpenHandsConfig
 from openhands.events import EventStream

+# Mock LLMRegistry
 from openhands.runtime.impl.cli.cli_runtime import CLIRuntime
 from openhands.storage import get_file_store


+# Create a mock LLMRegistry class
+class MockLLMRegistry:
+    def __init__(self, config):
+        self.config = config
+
+
 @pytest.fixture
 def temp_dir():
     """Create a temporary directory for testing."""
@@ -25,7 +33,8 @@ def cli_runtime(temp_dir):
     event_stream = EventStream('test', file_store)
     config = OpenHandsConfig()
     config.workspace_base = temp_dir
-    runtime = CLIRuntime(config, event_stream)
+    llm_registry = MockLLMRegistry(config)
+    runtime = CLIRuntime(config, event_stream, llm_registry)
     runtime._runtime_initialized = True  # Skip initialization
     return runtime

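Note: this file fakes the registry with a two-line duck type instead of
importing the real one, which keeps the runtime test free of any LLM setup.
That only works as long as CLIRuntime merely stores the object it is given; a
sketch of the implied contract (an assumption about CLIRuntime's constructor,
not something this diff verifies):

    class MockLLMRegistry:
        def __init__(self, config):
            # CLIRuntime is assumed not to call anything beyond attribute access
            self.config = config

    runtime = CLIRuntime(config, event_stream, MockLLMRegistry(config))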
@@ -17,6 +17,7 @@ from openhands.core.config.condenser_config import (
|
|||||||
StructuredSummaryCondenserConfig,
|
StructuredSummaryCondenserConfig,
|
||||||
)
|
)
|
||||||
from openhands.core.config.llm_config import LLMConfig
|
from openhands.core.config.llm_config import LLMConfig
|
||||||
|
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||||
from openhands.core.message import Message, TextContent
|
from openhands.core.message import Message, TextContent
|
||||||
from openhands.core.schema.action import ActionType
|
from openhands.core.schema.action import ActionType
|
||||||
from openhands.events.event import Event, EventSource
|
from openhands.events.event import Event, EventSource
|
||||||
@@ -24,6 +25,7 @@ from openhands.events.observation import BrowserOutputObservation
|
|||||||
from openhands.events.observation.agent import AgentCondensationObservation
|
from openhands.events.observation.agent import AgentCondensationObservation
|
||||||
from openhands.events.observation.observation import Observation
|
from openhands.events.observation.observation import Observation
|
||||||
from openhands.llm import LLM
|
from openhands.llm import LLM
|
||||||
|
from openhands.llm.llm_registry import LLMRegistry
|
||||||
from openhands.memory.condenser import Condenser
|
from openhands.memory.condenser import Condenser
|
||||||
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
|
from openhands.memory.condenser.condenser import Condensation, RollingCondenser, View
|
||||||
from openhands.memory.condenser.impl import (
|
from openhands.memory.condenser.impl import (
|
||||||
@@ -38,6 +40,7 @@ from openhands.memory.condenser.impl import (
|
|||||||
StructuredSummaryCondenser,
|
StructuredSummaryCondenser,
|
||||||
)
|
)
|
||||||
from openhands.memory.condenser.impl.pipeline import CondenserPipeline
|
from openhands.memory.condenser.impl.pipeline import CondenserPipeline
|
||||||
|
from openhands.server.services.conversation_stats import ConversationStats
|
||||||
|
|
||||||
|
|
||||||
def create_test_event(
|
def create_test_event(
|
||||||
@@ -56,12 +59,15 @@ def create_test_event(
 @pytest.fixture
 def mock_llm() -> LLM:
     """Mocks an LLM object with a utility function for setting and resetting response contents in unit tests."""
+    # Create a real LLMConfig instead of a mock to properly handle SecretStr api_key
+    real_config = LLMConfig(
+        model='gpt-4o', api_key='test_key', custom_llm_provider=None
+    )
+
     # Create a MagicMock for the LLM object
     mock_llm = MagicMock(
         spec=LLM,
-        config=MagicMock(
-            spec=LLMConfig, model='gpt-4o', api_key='test_key', custom_llm_provider=None
-        ),
+        config=real_config,
         metrics=MagicMock(),
     )
     _mock_content = None
@@ -95,6 +101,23 @@ def mock_llm() -> LLM:
     return mock_llm


+@pytest.fixture
+def mock_conversation_stats() -> ConversationStats:
+    """Creates a mock ConversationStats service."""
+    mock_stats = MagicMock(spec=ConversationStats)
+    return mock_stats
+
+
+@pytest.fixture
+def mock_llm_registry(mock_llm, mock_conversation_stats) -> LLMRegistry:
+    """Creates an actual LLMRegistry that returns real LLMs."""
+    # Create an actual LLMRegistry with a basic OpenHandsConfig
+    config = OpenHandsConfig()
+    registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
+
+    return registry
+
+
 class RollingCondenserTestHarness:
     """Test harness for rolling condensers.
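The fixtures above feed every from_config test that follows: Condenser.from_config now takes the registry as a second argument, so LLM-backed condensers can resolve their LLM through it instead of constructing one. A hedged sketch of that pattern; only registry.get_llm(service_id, llm_config) is confirmed by this commit's tests, while the service-id string and the wiring are assumptions:

    from openhands.llm.llm_registry import LLMRegistry
    from openhands.memory.condenser.impl import LLMSummarizingCondenser


    def build_summarizing_condenser(config, registry: LLMRegistry):
        # The registry hands back (and caches) the LLM for a named service,
        # so the condenser no longer instantiates LLM objects itself.
        llm = registry.get_llm('condenser', config.llm_config)  # service id assumed
        return LLMSummarizingCondenser(
            llm=llm, max_size=config.max_size, keep_first=config.keep_first
        )
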
@@ -165,10 +188,10 @@ class RollingCondenserTestHarness:
         return ((index - max_size) // target_size) + 1


-def test_noop_condenser_from_config():
+def test_noop_condenser_from_config(mock_llm_registry):
     """Test that the NoOpCondenser objects can be made from config."""
     config = NoOpCondenserConfig()
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, NoOpCondenser)

@@ -189,11 +212,11 @@ def test_noop_condenser():
     assert result == View(events=events)


-def test_observation_masking_condenser_from_config():
+def test_observation_masking_condenser_from_config(mock_llm_registry):
     """Test that ObservationMaskingCondenser objects can be made from config."""
     attention_window = 5
     config = ObservationMaskingCondenserConfig(attention_window=attention_window)
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, ObservationMaskingCondenser)
     assert condenser.attention_window == attention_window
@@ -229,11 +252,11 @@ def test_observation_masking_condenser_respects_attention_window():
             assert event == condensed_event


-def test_browser_output_condenser_from_config():
+def test_browser_output_condenser_from_config(mock_llm_registry):
     """Test that BrowserOutputCondenser objects can be made from config."""
     attention_window = 5
     config = BrowserOutputCondenserConfig(attention_window=attention_window)
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, BrowserOutputCondenser)
     assert condenser.attention_window == attention_window
@@ -271,12 +294,12 @@ def test_browser_output_condenser_respects_attention_window():
             assert event == condensed_event


-def test_recent_events_condenser_from_config():
+def test_recent_events_condenser_from_config(mock_llm_registry):
     """Test that RecentEventsCondenser objects can be made from config."""
     max_events = 5
     keep_first = True
     config = RecentEventsCondenserConfig(keep_first=keep_first, max_events=max_events)
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, RecentEventsCondenser)
     assert condenser.max_events == max_events
@@ -334,14 +357,14 @@ def test_recent_events_condenser():
     assert result[2]._message == 'Event 5'  # kept from max_events


-def test_llm_summarizing_condenser_from_config():
+def test_llm_summarizing_condenser_from_config(mock_llm_registry):
     """Test that LLMSummarizingCondenser objects can be made from config."""
     config = LLMSummarizingCondenserConfig(
         max_size=50,
         keep_first=10,
         llm_config=LLMConfig(model='gpt-4o', api_key='test_key', caching_prompt=True),
     )
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, LLMSummarizingCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
@@ -349,25 +372,33 @@ def test_llm_summarizing_condenser_from_config():
     assert condenser.max_size == 50
     assert condenser.keep_first == 10
-
-    # Since this condenser can't take advantage of caching, we intercept the
-    # passed config and manually flip the caching prompt to False.
-    assert not condenser.llm.config.caching_prompt


-def test_llm_summarizing_condenser_invalid_config():
+def test_llm_summarizing_condenser_invalid_config(mock_llm, mock_llm_registry):
     """Test that LLMSummarizingCondenser raises error when keep_first > max_size."""
     pytest.raises(
         ValueError,
         LLMSummarizingCondenser,
-        llm=MagicMock(),
+        llm=mock_llm,
         max_size=4,
         keep_first=2,
     )
-    pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), max_size=0)
-    pytest.raises(ValueError, LLMSummarizingCondenser, llm=MagicMock(), keep_first=-1)
+    pytest.raises(
+        ValueError,
+        LLMSummarizingCondenser,
+        llm=mock_llm,
+        max_size=0,
+    )
+    pytest.raises(
+        ValueError,
+        LLMSummarizingCondenser,
+        llm=mock_llm,
+        keep_first=-1,
+    )


-def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
+def test_llm_summarizing_condenser_gives_expected_view_size(
+    mock_llm, mock_llm_registry
+):
     """Test that LLMSummarizingCondenser maintains the correct view size."""
     max_size = 10
     condenser = LLMSummarizingCondenser(max_size=max_size, llm=mock_llm)
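The rewritten invalid-config test pins down the constructor contract shared by the summarizing condensers: max_size must be positive, keep_first non-negative, and keep_first must stay below half of max_size (the developer note in the attention-condenser tests further down spells out the last bound). A sketch of that validation; the bounds come from the tests, the exact messages are assumptions:

    def validate_condenser_params(max_size: int, keep_first: int) -> None:
        # Bounds inferred from the ValueError cases exercised above.
        if max_size <= 0:
            raise ValueError(f'max_size must be positive, got {max_size}')
        if keep_first < 0:
            raise ValueError(f'keep_first must be non-negative, got {keep_first}')
        if keep_first >= max_size // 2:
            # e.g. max_size=4, keep_first=2 fails because 2 >= 4 // 2
            raise ValueError('keep_first must be less than half of max_size')
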
@@ -383,12 +414,16 @@ def test_llm_summarizing_condenser_gives_expected_view_size(mock_llm):
         assert len(view) == harness.expected_size(i, max_size)


-def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
+def test_llm_summarizing_condenser_keeps_first_and_summary_events(
+    mock_llm, mock_llm_registry
+):
     """Test that the LLM summarizing condenser appropriately maintains the event prefix and any summary events."""
     max_size = 10
     keep_first = 3
     condenser = LLMSummarizingCondenser(
-        max_size=max_size, keep_first=keep_first, llm=mock_llm
+        max_size=max_size,
+        keep_first=keep_first,
+        llm=mock_llm,
     )

     mock_llm.set_mock_response_content('Summary of forgotten events')
@@ -412,14 +447,14 @@ def test_llm_summarizing_condenser_keeps_first_and_summary_events(mock_llm):
         assert isinstance(view[keep_first], AgentCondensationObservation)


-def test_amortized_forgetting_condenser_from_config():
+def test_amortized_forgetting_condenser_from_config(mock_llm_registry):
     """Test that AmortizedForgettingCondenser objects can be made from config."""
     max_size = 50
     keep_first = 10
     config = AmortizedForgettingCondenserConfig(
         max_size=max_size, keep_first=keep_first
     )
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, AmortizedForgettingCondenser)
     assert condenser.max_size == max_size
@@ -475,7 +510,7 @@ def test_amortized_forgetting_condenser_keeps_first_and_last_events():
         assert view[:keep_first] == events[: min(keep_first, i + 1)]


-def test_llm_attention_condenser_from_config():
+def test_llm_attention_condenser_from_config(mock_llm_registry):
     """Test that LLMAttentionCondenser objects can be made from config."""
     config = LLMAttentionCondenserConfig(
         max_size=50,
@@ -486,37 +521,32 @@ def test_llm_attention_condenser_from_config():
             caching_prompt=True,
         ),
     )
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, LLMAttentionCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
-    assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
     assert condenser.max_size == 50
     assert condenser.keep_first == 10

-    # Since this condenser can't take advantage of caching, we intercept the
-    # passed config and manually flip the caching prompt to False.
-    assert not condenser.llm.config.caching_prompt
+    # Create a mock LLM that doesn't support function calling
+    mock_llm = MagicMock()
+    mock_llm.is_function_calling_active.return_value = False
+
+    # Create a new registry that returns our mock LLM that doesn't support function calling
+    mock_registry = MagicMock(spec=LLMRegistry)
+    mock_registry.get_llm.return_value = mock_llm
+
+    pytest.raises(ValueError, LLMAttentionCondenser.from_config, config, mock_registry)


-def test_llm_attention_condenser_invalid_config():
-    """Test that LLMAttentionCondenser raises an error if the configured LLM doesn't support response schema."""
-    config = LLMAttentionCondenserConfig(
-        max_size=50,
-        keep_first=10,
-        llm_config=LLMConfig(
-            model='claude-2',  # Older model that doesn't support response schema
-            api_key='test_key',
-        ),
-    )
-
-    pytest.raises(ValueError, LLMAttentionCondenser.from_config, config)
-
-
-def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
+def test_llm_attention_condenser_gives_expected_view_size(mock_llm, mock_llm_registry):
     """Test that the LLMAttentionCondenser gives views of the expected size."""
     max_size = 10
-    condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
+    condenser = LLMAttentionCondenser(
+        max_size=max_size,
+        keep_first=0,
+        llm=mock_llm,
+    )

     events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
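The rewritten from_config test above forces the failure path through the registry rather than through a model name: a registry whose get_llm returns an LLM reporting is_function_calling_active() as False must make from_config raise. A sketch of that gate; the internals are assumed, only the two method names appear in the tests:

    def attention_from_config_sketch(config, registry):
        llm = registry.get_llm('condenser', config.llm_config)  # service id assumed
        if not llm.is_function_calling_active():
            raise ValueError(
                'LLMAttentionCondenser requires an LLM with function calling active'
            )
        # ... construct and return the condenser from config and llm ...
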
@@ -534,10 +564,16 @@ def test_llm_attention_condenser_gives_expected_view_size(mock_llm):
         assert len(view) == harness.expected_size(i, max_size)


-def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
+def test_llm_attention_condenser_handles_events_outside_history(
+    mock_llm, mock_llm_registry
+):
     """Test that the LLMAttentionCondenser handles event IDs that aren't from the event history."""
     max_size = 2
-    condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
+    condenser = LLMAttentionCondenser(
+        max_size=max_size,
+        keep_first=0,
+        llm=mock_llm,
+    )

     events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -555,10 +591,14 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm):
         assert len(view) == harness.expected_size(i, max_size)


-def test_llm_attention_condenser_handles_too_many_events(mock_llm):
+def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_llm_registry):
     """Test that the LLMAttentionCondenser handles when the response contains too many event IDs."""
     max_size = 2
-    condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
+    condenser = LLMAttentionCondenser(
+        max_size=max_size,
+        keep_first=0,
+        llm=mock_llm,
+    )

     events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -576,12 +616,16 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm):
         assert len(view) == harness.expected_size(i, max_size)


-def test_llm_attention_condenser_handles_too_few_events(mock_llm):
+def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_llm_registry):
     """Test that the LLMAttentionCondenser handles when the response contains too few event IDs."""
     max_size = 2
     # Developer note: We must specify keep_first=0 because
     # keep_first (1) >= max_size//2 (1) is invalid.
-    condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm)
+    condenser = LLMAttentionCondenser(
+        max_size=max_size,
+        keep_first=0,
+        llm=mock_llm,
+    )

     events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -597,12 +641,14 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm):
         assert len(view) == harness.expected_size(i, max_size)


-def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
+def test_llm_attention_condenser_handles_keep_first_events(mock_llm, mock_llm_registry):
     """Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size)."""
     max_size = 12
     keep_first = 4
     condenser = LLMAttentionCondenser(
-        max_size=max_size, keep_first=keep_first, llm=mock_llm
+        max_size=max_size,
+        keep_first=keep_first,
+        llm=mock_llm,
     )

     events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -620,7 +666,7 @@ def test_llm_attention_condenser_handles_keep_first_events(mock_llm):
         assert view[:keep_first] == events[: min(keep_first, i + 1)]


-def test_structured_summary_condenser_from_config():
+def test_structured_summary_condenser_from_config(mock_llm_registry):
     """Test that StructuredSummaryCondenser objects can be made from config."""
     config = StructuredSummaryCondenserConfig(
         max_size=50,
@@ -631,7 +677,7 @@ def test_structured_summary_condenser_from_config():
             caching_prompt=True,
         ),
     )
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, StructuredSummaryCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
@@ -639,40 +685,55 @@ def test_structured_summary_condenser_from_config():
     assert condenser.max_size == 50
     assert condenser.keep_first == 10
-
-    # Since this condenser can't take advantage of caching, we intercept the
-    # passed config and manually flip the caching prompt to False.
-    assert not condenser.llm.config.caching_prompt


-def test_structured_summary_condenser_invalid_config():
+def test_structured_summary_condenser_invalid_config(mock_llm):
     """Test that StructuredSummaryCondenser raises error when keep_first > max_size."""
     # Since the condenser only works when function calling is on, we need to
     # mock up the check for that.
-    llm = MagicMock()
-    llm.is_function_calling_active.return_value = True
+    mock_llm.is_function_calling_active.return_value = True

     pytest.raises(
         ValueError,
         StructuredSummaryCondenser,
-        llm=llm,
+        llm=mock_llm,
         max_size=4,
         keep_first=2,
     )
-    pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, max_size=0)
-    pytest.raises(ValueError, StructuredSummaryCondenser, llm=llm, keep_first=-1)
+    pytest.raises(
+        ValueError,
+        StructuredSummaryCondenser,
+        llm=mock_llm,
+        max_size=0,
+    )
+    pytest.raises(
+        ValueError,
+        StructuredSummaryCondenser,
+        llm=mock_llm,
+        keep_first=-1,
+    )

     # If all other parameters are good but there's no function calling the
     # condenser still counts as improperly configured.
-    llm.is_function_calling_active.return_value = False
+    # Create a mock LLM that doesn't support function calling
+    mock_llm_no_func = MagicMock()
+    mock_llm_no_func.is_function_calling_active.return_value = False

     pytest.raises(
-        ValueError, StructuredSummaryCondenser, llm=llm, max_size=40, keep_first=2
+        ValueError,
+        StructuredSummaryCondenser,
+        llm=mock_llm_no_func,
+        max_size=40,
+        keep_first=2,
     )


-def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
+def test_structured_summary_condenser_gives_expected_view_size(
+    mock_llm, mock_llm_registry
+):
     """Test that StructuredSummaryCondenser maintains the correct view size."""
     max_size = 10
+    mock_llm.is_function_calling_active.return_value = True
     condenser = StructuredSummaryCondenser(max_size=max_size, llm=mock_llm)

     events = [create_test_event(f'Event {i}', id=i) for i in range(max_size * 10)]
@@ -686,12 +747,17 @@ def test_structured_summary_condenser_gives_expected_view_size(mock_llm):
         assert len(view) == harness.expected_size(i, max_size)


-def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
+def test_structured_summary_condenser_keeps_first_and_summary_events(
+    mock_llm, mock_llm_registry
+):
     """Test that the StructuredSummaryCondenser appropriately maintains the event prefix and any summary events."""
     max_size = 10
     keep_first = 3
+    mock_llm.is_function_calling_active.return_value = True
     condenser = StructuredSummaryCondenser(
-        max_size=max_size, keep_first=keep_first, llm=mock_llm
+        max_size=max_size,
+        keep_first=keep_first,
+        llm=mock_llm,
     )

     mock_llm.set_mock_response_content('Summary of forgotten events')
@@ -715,7 +781,7 @@ def test_structured_summary_condenser_keeps_first_and_summary_events(mock_llm):
         assert isinstance(view[keep_first], AgentCondensationObservation)


-def test_condenser_pipeline_from_config():
+def test_condenser_pipeline_from_config(mock_llm_registry):
     """Test that CondenserPipeline condensers can be created from configuration objects."""
     config = CondenserPipelineConfig(
         condensers=[
@@ -728,7 +794,7 @@ def test_condenser_pipeline_from_config():
             ),
         ]
     )
-    condenser = Condenser.from_config(config)
+    condenser = Condenser.from_config(config, mock_llm_registry)

     assert isinstance(condenser, CondenserPipeline)
     assert len(condenser.condensers) == 3
tests/unit/test_conversation_stats.py (new file, 490 lines)
@@ -0,0 +1,490 @@
+import base64
+import pickle
+from unittest.mock import patch
+
+import pytest
+
+from openhands.core.config import LLMConfig, OpenHandsConfig
+from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
+from openhands.llm.metrics import Metrics
+from openhands.server.services.conversation_stats import ConversationStats
+from openhands.storage.memory import InMemoryFileStore
+
+
+@pytest.fixture
+def mock_file_store():
+    """Create a mock file store for testing."""
+    return InMemoryFileStore({})
+
+
+@pytest.fixture
+def conversation_stats(mock_file_store):
+    """Create a ConversationStats instance for testing."""
+    return ConversationStats(
+        file_store=mock_file_store,
+        conversation_id='test-conversation-id',
+        user_id='test-user-id',
+    )
+
+
+@pytest.fixture
+def mock_llm_registry():
+    """Create a mock LLM registry that properly simulates LLM registration."""
+    config = OpenHandsConfig()
+    registry = LLMRegistry(config=config, agent_cls=None, retry_listener=None)
+    return registry
+
+
+@pytest.fixture
+def connected_registry_and_stats(mock_llm_registry, conversation_stats):
+    """Connect the LLMRegistry and ConversationStats properly."""
+    # Subscribe to LLM registry events to track metrics
+    mock_llm_registry.subscribe(conversation_stats.register_llm)
+    return mock_llm_registry, conversation_stats
+
+
+def test_conversation_stats_initialization(conversation_stats):
+    """Test that ConversationStats initializes correctly."""
+    assert conversation_stats.conversation_id == 'test-conversation-id'
+    assert conversation_stats.user_id == 'test-user-id'
+    assert conversation_stats.service_to_metrics == {}
+    assert isinstance(conversation_stats.restored_metrics, dict)
+
+
+def test_save_metrics(conversation_stats, mock_file_store):
+    """Test that metrics are saved correctly."""
+    # Add a service with metrics
+    service_id = 'test-service'
+    metrics = Metrics(model_name='gpt-4')
+    metrics.add_cost(0.05)
+    conversation_stats.service_to_metrics[service_id] = metrics
+
+    # Save metrics
+    conversation_stats.save_metrics()
+
+    # Verify that metrics were saved to the file store
+    try:
+        # Verify the saved content can be decoded and unpickled
+        encoded = mock_file_store.read(conversation_stats.metrics_path)
+        pickled = base64.b64decode(encoded)
+        restored = pickle.loads(pickled)
+
+        assert service_id in restored
+        assert restored[service_id].accumulated_cost == 0.05
+    except FileNotFoundError:
+        pytest.fail(f'File not found: {conversation_stats.metrics_path}')
+
+
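test_save_metrics fixes the on-disk format: the service-to-Metrics mapping is pickled, then base64-encoded so the payload is text-safe in the file store. A standalone round-trip sketch, with plain dicts standing in for Metrics objects:

    import base64
    import pickle

    service_to_metrics = {'test-service': {'accumulated_cost': 0.05}}

    # Save path: pickle the mapping, then base64-encode the bytes.
    encoded = base64.b64encode(pickle.dumps(service_to_metrics)).decode('utf-8')

    # Restore path: reverse the two steps.
    restored = pickle.loads(base64.b64decode(encoded))
    assert restored['test-service']['accumulated_cost'] == 0.05
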
+def test_maybe_restore_metrics(mock_file_store):
+    """Test that metrics are restored correctly."""
+    # Create metrics to save
+    service_id = 'test-service'
+    metrics = Metrics(model_name='gpt-4')
+    metrics.add_cost(0.1)
+    service_to_metrics = {service_id: metrics}
+
+    # Serialize and save metrics
+    pickled = pickle.dumps(service_to_metrics)
+    serialized_metrics = base64.b64encode(pickled).decode('utf-8')
+
+    # Create a new ConversationStats with pre-populated file store
+    conversation_id = 'test-conversation-id'
+    user_id = 'test-user-id'
+
+    # Get the correct path using the same function as ConversationStats
+    from openhands.storage.locations import get_conversation_stats_filename
+
+    metrics_path = get_conversation_stats_filename(conversation_id, user_id)
+
+    # Write to the correct path
+    mock_file_store.write(metrics_path, serialized_metrics)
+
+    # Create ConversationStats which should restore metrics
+    stats = ConversationStats(
+        file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
+    )
+
+    # Verify metrics were restored
+    assert service_id in stats.restored_metrics
+    assert stats.restored_metrics[service_id].accumulated_cost == 0.1
+
+
+def test_get_combined_metrics(conversation_stats):
+    """Test that combined metrics are calculated correctly."""
+    # Add multiple services with metrics
+    service1 = 'service1'
+    metrics1 = Metrics(model_name='gpt-4')
+    metrics1.add_cost(0.05)
+    metrics1.add_token_usage(
+        prompt_tokens=100,
+        completion_tokens=50,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        context_window=8000,
+        response_id='resp1',
+    )
+
+    service2 = 'service2'
+    metrics2 = Metrics(model_name='gpt-3.5')
+    metrics2.add_cost(0.02)
+    metrics2.add_token_usage(
+        prompt_tokens=200,
+        completion_tokens=100,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        context_window=4000,
+        response_id='resp2',
+    )
+
+    conversation_stats.service_to_metrics[service1] = metrics1
+    conversation_stats.service_to_metrics[service2] = metrics2
+
+    # Get combined metrics
+    combined = conversation_stats.get_combined_metrics()
+
+    # Verify combined metrics
+    assert combined.accumulated_cost == 0.07  # 0.05 + 0.02
+    assert combined.accumulated_token_usage.prompt_tokens == 300  # 100 + 200
+    assert combined.accumulated_token_usage.completion_tokens == 150  # 50 + 100
+    assert (
+        combined.accumulated_token_usage.context_window == 8000
+    )  # max of 8000 and 4000
+
+
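test_get_combined_metrics nails down the combination rule: costs and token counts sum across services, while context_window takes the maximum. A sketch of that fold; the field names mirror the asserts, while the real Metrics class carries more state:

    from dataclasses import dataclass


    @dataclass
    class UsageSketch:
        cost: float = 0.0
        prompt_tokens: int = 0
        completion_tokens: int = 0
        context_window: int = 0


    def combine(per_service: list[UsageSketch]) -> UsageSketch:
        total = UsageSketch()
        for usage in per_service:
            total.cost += usage.cost
            total.prompt_tokens += usage.prompt_tokens
            total.completion_tokens += usage.completion_tokens
            total.context_window = max(total.context_window, usage.context_window)
        return total


    combined = combine(
        [UsageSketch(0.05, 100, 50, 8000), UsageSketch(0.02, 200, 100, 4000)]
    )
    assert combined.context_window == 8000  # max, not sum
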
+def test_get_metrics_for_service(conversation_stats):
+    """Test that metrics for a specific service are retrieved correctly."""
+    # Add a service with metrics
+    service_id = 'test-service'
+    metrics = Metrics(model_name='gpt-4')
+    metrics.add_cost(0.05)
+    conversation_stats.service_to_metrics[service_id] = metrics
+
+    # Get metrics for the service
+    retrieved_metrics = conversation_stats.get_metrics_for_service(service_id)
+
+    # Verify metrics
+    assert retrieved_metrics.accumulated_cost == 0.05
+    assert retrieved_metrics is metrics  # Should be the same object
+
+    # Test getting metrics for non-existent service
+    # Use a specific exception message pattern instead of a blind Exception
+    with pytest.raises(Exception, match='LLM service does not exist'):
+        conversation_stats.get_metrics_for_service('non-existent-service')
+
+
+def test_register_llm_with_new_service(conversation_stats):
+    """Test registering a new LLM service."""
+    # Create a real LLM instance with a mock config
+    llm_config = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    # Patch the LLM class to avoid actual API calls
+    with patch('openhands.llm.llm.litellm_completion'):
+        llm = LLM(service_id='new-service', config=llm_config)
+
+    # Create a registry event
+    service_id = 'new-service'
+    event = RegistryEvent(llm=llm, service_id=service_id)
+
+    # Register the LLM
+    conversation_stats.register_llm(event)
+
+    # Verify the service was registered
+    assert service_id in conversation_stats.service_to_metrics
+    assert conversation_stats.service_to_metrics[service_id] is llm.metrics
+
+
+def test_register_llm_with_restored_metrics(conversation_stats):
+    """Test registering an LLM service with restored metrics."""
+    # Create restored metrics
+    service_id = 'restored-service'
+    restored_metrics = Metrics(model_name='gpt-4')
+    restored_metrics.add_cost(0.1)
+    conversation_stats.restored_metrics = {service_id: restored_metrics}
+
+    # Create a real LLM instance with a mock config
+    llm_config = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    # Patch the LLM class to avoid actual API calls
+    with patch('openhands.llm.llm.litellm_completion'):
+        llm = LLM(service_id=service_id, config=llm_config)
+
+    # Create a registry event
+    event = RegistryEvent(llm=llm, service_id=service_id)
+
+    # Register the LLM
+    conversation_stats.register_llm(event)
+
+    # Verify the service was registered with restored metrics
+    assert service_id in conversation_stats.service_to_metrics
+    assert conversation_stats.service_to_metrics[service_id] is llm.metrics
+    assert llm.metrics.accumulated_cost == 0.1  # Restored cost
+
+    # Verify the specific service was removed from restored_metrics
+    assert service_id not in conversation_stats.restored_metrics
+    assert hasattr(
+        conversation_stats, 'restored_metrics'
+    )  # The dict should still exist
+
+
+def test_llm_registry_notifications(connected_registry_and_stats):
+    """Test that LLM registry notifications update conversation stats."""
+    mock_llm_registry, conversation_stats = connected_registry_and_stats
+
+    # Create a new LLM through the registry
+    service_id = 'test-service'
+    llm_config = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    # Get LLM from registry (this should trigger the notification)
+    llm = mock_llm_registry.get_llm(service_id, llm_config)
+
+    # Verify the service was registered in conversation stats
+    assert service_id in conversation_stats.service_to_metrics
+    assert conversation_stats.service_to_metrics[service_id] is llm.metrics
+
+    # Add some metrics to the LLM
+    llm.metrics.add_cost(0.05)
+    llm.metrics.add_token_usage(
+        prompt_tokens=100,
+        completion_tokens=50,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        context_window=8000,
+        response_id='resp1',
+    )
+
+    # Verify the metrics are reflected in conversation stats
+    assert conversation_stats.service_to_metrics[service_id].accumulated_cost == 0.05
+    assert (
+        conversation_stats.service_to_metrics[
+            service_id
+        ].accumulated_token_usage.prompt_tokens
+        == 100
+    )
+    assert (
+        conversation_stats.service_to_metrics[
+            service_id
+        ].accumulated_token_usage.completion_tokens
+        == 50
+    )
+
+    # Get combined metrics and verify
+    combined = conversation_stats.get_combined_metrics()
+    assert combined.accumulated_cost == 0.05
+    assert combined.accumulated_token_usage.prompt_tokens == 100
+    assert combined.accumulated_token_usage.completion_tokens == 50
+
+
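The notification flow above is a plain observer pattern: ConversationStats subscribes its register_llm callback, and the registry fires a RegistryEvent whenever get_llm builds a new LLM. A sketch of that wiring; the internal attribute names are assumptions, only subscribe, get_llm, and RegistryEvent appear in the tests:

    from typing import Any, Callable


    class RegistrySketch:
        def __init__(self) -> None:
            self.service_to_llm: dict[str, Any] = {}
            self._subscribers: list[Callable[[Any], None]] = []

        def subscribe(self, callback: Callable[[Any], None]) -> None:
            self._subscribers.append(callback)

        def get_llm(self, service_id: str, config: Any) -> Any:
            if service_id not in self.service_to_llm:
                llm = object()  # stands in for building a real LLM
                self.service_to_llm[service_id] = llm
                for notify in self._subscribers:
                    notify((service_id, llm))  # stands in for RegistryEvent
            return self.service_to_llm[service_id]
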
+def test_multiple_llm_services(connected_registry_and_stats):
+    """Test tracking metrics for multiple LLM services."""
+    mock_llm_registry, conversation_stats = connected_registry_and_stats
+
+    # Create multiple LLMs through the registry
+    service1 = 'service1'
+    service2 = 'service2'
+
+    llm_config1 = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    llm_config2 = LLMConfig(
+        model='gpt-3.5-turbo',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    # Get LLMs from registry (this should trigger notifications)
+    llm1 = mock_llm_registry.get_llm(service1, llm_config1)
+    llm2 = mock_llm_registry.get_llm(service2, llm_config2)
+
+    # Add different metrics to each LLM
+    llm1.metrics.add_cost(0.05)
+    llm1.metrics.add_token_usage(
+        prompt_tokens=100,
+        completion_tokens=50,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        context_window=8000,
+        response_id='resp1',
+    )
+
+    llm2.metrics.add_cost(0.02)
+    llm2.metrics.add_token_usage(
+        prompt_tokens=200,
+        completion_tokens=100,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        context_window=4000,
+        response_id='resp2',
+    )
+
+    # Verify services were registered in conversation stats
+    assert service1 in conversation_stats.service_to_metrics
+    assert service2 in conversation_stats.service_to_metrics
+
+    # Verify individual metrics
+    assert conversation_stats.service_to_metrics[service1].accumulated_cost == 0.05
+    assert conversation_stats.service_to_metrics[service2].accumulated_cost == 0.02
+
+    # Get combined metrics and verify
+    combined = conversation_stats.get_combined_metrics()
+    assert combined.accumulated_cost == 0.07  # 0.05 + 0.02
+    assert combined.accumulated_token_usage.prompt_tokens == 300  # 100 + 200
+    assert combined.accumulated_token_usage.completion_tokens == 150  # 50 + 100
+    assert (
+        combined.accumulated_token_usage.context_window == 8000
+    )  # max of 8000 and 4000
+
+
+def test_register_llm_with_multiple_restored_services_bug(conversation_stats):
+    """Test that reproduces the bug where del self.restored_metrics deletes entire dict instead of specific service."""
+    # Create restored metrics for multiple services
+    service_id_1 = 'service-1'
+    service_id_2 = 'service-2'
+
+    restored_metrics_1 = Metrics(model_name='gpt-4')
+    restored_metrics_1.add_cost(0.1)
+
+    restored_metrics_2 = Metrics(model_name='gpt-3.5')
+    restored_metrics_2.add_cost(0.05)
+
+    # Set up restored metrics for both services
+    conversation_stats.restored_metrics = {
+        service_id_1: restored_metrics_1,
+        service_id_2: restored_metrics_2,
+    }
+
+    # Create LLM configs
+    llm_config_1 = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    llm_config_2 = LLMConfig(
+        model='gpt-3.5-turbo',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    # Patch the LLM class to avoid actual API calls
+    with patch('openhands.llm.llm.litellm_completion'):
+        # Register first LLM
+        llm_1 = LLM(service_id=service_id_1, config=llm_config_1)
+        event_1 = RegistryEvent(llm=llm_1, service_id=service_id_1)
+        conversation_stats.register_llm(event_1)
+
+        # Verify first service was registered with restored metrics
+        assert service_id_1 in conversation_stats.service_to_metrics
+        assert llm_1.metrics.accumulated_cost == 0.1
+
+        # After registering first service, restored_metrics should still contain service_id_2
+        assert service_id_2 in conversation_stats.restored_metrics
+
+        # Register second LLM - this should also work with restored metrics
+        llm_2 = LLM(service_id=service_id_2, config=llm_config_2)
+        event_2 = RegistryEvent(llm=llm_2, service_id=service_id_2)
+        conversation_stats.register_llm(event_2)
+
+        # Verify second service was registered with restored metrics
+        assert service_id_2 in conversation_stats.service_to_metrics
+        assert llm_2.metrics.accumulated_cost == 0.05
+
+        # After both services are registered, restored_metrics should be empty
+        assert len(conversation_stats.restored_metrics) == 0
+
+
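The regression test above encodes the fix for the restore bug: register_llm must remove only the one restored entry it consumed, not the whole dict, and dict.pop with a default does exactly that. A sketch of the behavior; the function body is assumed, only its observable effects come from the test:

    def register_llm_sketch(stats, event) -> None:
        """Consume at most one restored entry; leave the rest intact."""
        service_id = event.service_id
        # pop the single service, rather than `del stats.restored_metrics`
        restored = stats.restored_metrics.pop(service_id, None)
        if restored is not None:
            event.llm.metrics = restored
        stats.service_to_metrics[service_id] = event.llm.metrics
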
+def test_save_and_restore_workflow(mock_file_store):
+    """Test the full workflow of saving and restoring metrics."""
+    # Create initial conversation stats
+    conversation_id = 'test-conversation-id'
+    user_id = 'test-user-id'
+
+    stats1 = ConversationStats(
+        file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
+    )
+
+    # Add a service with metrics
+    service_id = 'test-service'
+    metrics = Metrics(model_name='gpt-4')
+    metrics.add_cost(0.05)
+    metrics.add_token_usage(
+        prompt_tokens=100,
+        completion_tokens=50,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        context_window=8000,
+        response_id='resp1',
+    )
+    stats1.service_to_metrics[service_id] = metrics
+
+    # Save metrics
+    stats1.save_metrics()
+
+    # Create a new conversation stats instance that should restore the metrics
+    stats2 = ConversationStats(
+        file_store=mock_file_store, conversation_id=conversation_id, user_id=user_id
+    )
+
+    # Verify metrics were restored
+    assert service_id in stats2.restored_metrics
+    assert stats2.restored_metrics[service_id].accumulated_cost == 0.05
+    assert (
+        stats2.restored_metrics[service_id].accumulated_token_usage.prompt_tokens == 100
+    )
+    assert (
+        stats2.restored_metrics[service_id].accumulated_token_usage.completion_tokens
+        == 50
+    )
+
+    # Create a real LLM instance with a mock config
+    llm_config = LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+    # Patch the LLM class to avoid actual API calls
+    with patch('openhands.llm.llm.litellm_completion'):
+        llm = LLM(service_id=service_id, config=llm_config)
+
+    # Create a registry event
+    event = RegistryEvent(llm=llm, service_id=service_id)
+
+    # Register the LLM to trigger restoration
+    stats2.register_llm(event)
+
+    # Verify metrics were applied to the LLM
+    assert llm.metrics.accumulated_cost == 0.05
+    assert llm.metrics.accumulated_token_usage.prompt_tokens == 100
+    assert llm.metrics.accumulated_token_usage.completion_tokens == 50
@@ -1,6 +1,6 @@
 """Tests for the conversation summary generator."""

-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock

 import pytest

@@ -11,55 +11,51 @@ from openhands.utils.conversation_summary import generate_conversation_title
 @pytest.mark.asyncio
 async def test_generate_conversation_title_empty_message():
     """Test that an empty message returns None."""
-    result = await generate_conversation_title('', MagicMock())
+    mock_llm_registry = MagicMock()
+    mock_llm_config = LLMConfig(model='test-model')
+
+    result = await generate_conversation_title('', mock_llm_config, mock_llm_registry)
     assert result is None

-    result = await generate_conversation_title('   ', MagicMock())
+    result = await generate_conversation_title(
+        '   ', mock_llm_config, mock_llm_registry
+    )
     assert result is None


 @pytest.mark.asyncio
 async def test_generate_conversation_title_success():
     """Test successful title generation."""
-    # Create a proper mock response
-    mock_response = MagicMock()
-    mock_response.choices = [MagicMock()]
-    mock_response.choices[0].message.content = 'Generated Title'
-
-    # Create a mock LLM instance with a synchronous completion method
-    mock_llm = MagicMock()
-    mock_llm.completion = MagicMock(return_value=mock_response)
-
-    # Patch the LLM class to return our mock
-    with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
-        result = await generate_conversation_title(
-            'Can you help me with Python?', LLMConfig(model='test-model')
-        )
+    # Create a mock LLM registry that returns a title
+    mock_llm_registry = MagicMock()
+    mock_llm_registry.request_extraneous_completion.return_value = 'Generated Title'
+
+    mock_llm_config = LLMConfig(model='test-model')
+
+    result = await generate_conversation_title(
+        'Can you help me with Python?', mock_llm_config, mock_llm_registry
+    )

     assert result == 'Generated Title'
     # Verify the mock was called with the expected arguments
-    mock_llm.completion.assert_called_once()
+    mock_llm_registry.request_extraneous_completion.assert_called_once()


 @pytest.mark.asyncio
 async def test_generate_conversation_title_long_title():
     """Test that long titles are truncated."""
-    # Create a proper mock response with a long title
-    mock_response = MagicMock()
-    mock_response.choices = [MagicMock()]
-    mock_response.choices[
-        0
-    ].message.content = 'This is a very long title that should be truncated because it exceeds the maximum length'
-
-    # Create a mock LLM instance with a synchronous completion method
-    mock_llm = MagicMock()
-    mock_llm.completion = MagicMock(return_value=mock_response)
-
-    # Patch the LLM class to return our mock
-    with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
-        result = await generate_conversation_title(
-            'Can you help me with Python?', LLMConfig(model='test-model'), max_length=30
-        )
+    # Create a mock LLM registry that returns a long title
+    mock_llm_registry = MagicMock()
+    mock_llm_registry.request_extraneous_completion.return_value = 'This is a very long title that should be truncated because it exceeds the maximum length'
+
+    mock_llm_config = LLMConfig(model='test-model')
+
+    result = await generate_conversation_title(
+        'Can you help me with Python?',
+        mock_llm_config,
+        mock_llm_registry,
+        max_length=30,
+    )

     # Verify the title is truncated correctly
     assert len(result) <= 30
@@ -69,15 +65,17 @@ async def test_generate_conversation_title_long_title():
 @pytest.mark.asyncio
 async def test_generate_conversation_title_exception():
     """Test that exceptions are handled gracefully."""
-    # Create a mock LLM instance with a synchronous completion method that raises an exception
-    mock_llm = MagicMock()
-    mock_llm.completion = MagicMock(side_effect=Exception('Test error'))
+    # Create a mock LLM registry that raises an exception
+    mock_llm_registry = MagicMock()
+    mock_llm_registry.request_extraneous_completion.side_effect = Exception(
+        'Test error'
+    )

-    # Patch the LLM class to return our mock
-    with patch('openhands.utils.conversation_summary.LLM', return_value=mock_llm):
-        result = await generate_conversation_title(
-            'Can you help me with Python?', LLMConfig(model='test-model')
-        )
+    mock_llm_config = LLMConfig(model='test-model')
+
+    result = await generate_conversation_title(
+        'Can you help me with Python?', mock_llm_config, mock_llm_registry
+    )

     # Verify that None is returned when an exception occurs
     assert result is None
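These tests fully pin the new shape of generate_conversation_title: it takes the message, an LLMConfig, and the registry; returns None for blank input or on any exception; and truncates to max_length. A behavior-equivalent sketch; the prompt wording, the truncation style, and the arguments passed to request_extraneous_completion are all assumptions, since the tests only assert the call count and return value:

    async def generate_conversation_title_sketch(
        message: str, llm_config, llm_registry, max_length: int = 50
    ) -> str | None:
        if not message.strip():
            return None
        try:
            title = llm_registry.request_extraneous_completion(
                # prompt construction assumed
                f'Give a short title for this conversation: {message}'
            )
        except Exception:
            return None
        if len(title) > max_length:
            title = title[: max_length - 3] + '...'  # truncation style assumed
        return title
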
@@ -4,6 +4,7 @@ import pytest

 from openhands.core.config import OpenHandsConfig
 from openhands.events import EventStream
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.impl.docker.docker_runtime import DockerRuntime


@@ -40,12 +41,17 @@ def event_stream():
     return MagicMock(spec=EventStream)


+@pytest.fixture
+def llm_registry():
+    return MagicMock(spec=LLMRegistry)
+
+
 @patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
 def test_container_stopped_when_keep_runtime_alive_false(
-    mock_stop_containers, mock_docker_client, config, event_stream
+    mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
 ):
     # Arrange
-    runtime = DockerRuntime(config, event_stream, sid='test-sid')
+    runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
     runtime.container = mock_docker_client.containers.get.return_value

     # Act
@@ -57,11 +63,11 @@ def test_container_stopped_when_keep_runtime_alive_false(

 @patch('openhands.runtime.impl.docker.docker_runtime.stop_all_containers')
 def test_container_not_stopped_when_keep_runtime_alive_true(
-    mock_stop_containers, mock_docker_client, config, event_stream
+    mock_stop_containers, mock_docker_client, config, event_stream, llm_registry
 ):
     # Arrange
     config.sandbox.keep_runtime_alive = True
-    runtime = DockerRuntime(config, event_stream, sid='test-sid')
+    runtime = DockerRuntime(config, event_stream, llm_registry, sid='test-sid')
     runtime.container = mock_docker_client.containers.get.return_value

     # Act
tests/unit/test_llm_registry.py (new file, 178 lines)
@@ -0,0 +1,178 @@
+from __future__ import annotations
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from openhands.core.config.llm_config import LLMConfig
+from openhands.core.config.openhands_config import OpenHandsConfig
+from openhands.llm.llm_registry import LLMRegistry, RegistryEvent
+
+
+class TestLLMRegistry(unittest.TestCase):
+    def setUp(self):
+        """Set up test environment before each test."""
+        # Create a basic LLM config for testing
+        self.llm_config = LLMConfig(model='test-model')
+
+        # Create a basic OpenHands config for testing
+        self.config = OpenHandsConfig(
+            llms={'llm': self.llm_config}, default_agent='CodeActAgent'
+        )
+
+        # Create a registry for testing
+        self.registry = LLMRegistry(config=self.config)
+
+    def test_get_llm_creates_new_llm(self):
+        """Test that get_llm creates a new LLM when service doesn't exist."""
+        service_id = 'test-service'
+
+        # Mock the _create_new_llm method to avoid actual LLM initialization
+        with patch.object(self.registry, '_create_new_llm') as mock_create:
+            mock_llm = MagicMock()
+            mock_llm.config = self.llm_config
+            mock_create.return_value = mock_llm
+
+            # Get LLM for the first time
+            llm = self.registry.get_llm(service_id, self.llm_config)
+
+            # Verify LLM was created and stored
+            self.assertEqual(llm, mock_llm)
+            mock_create.assert_called_once_with(
+                config=self.llm_config, service_id=service_id
+            )
+
+    def test_get_llm_returns_existing_llm(self):
+        """Test that get_llm returns existing LLM when service already exists."""
+        service_id = 'test-service'
+
+        # Mock the _create_new_llm method to avoid actual LLM initialization
+        with patch.object(self.registry, '_create_new_llm') as mock_create:
+            mock_llm = MagicMock()
+            mock_llm.config = self.llm_config
+            mock_create.return_value = mock_llm
+
+            # Get LLM for the first time
+            llm1 = self.registry.get_llm(service_id, self.llm_config)
+
+            # Manually add to registry to simulate existing LLM
+            self.registry.service_to_llm[service_id] = mock_llm
+
+            # Get LLM for the second time - should return the same instance
+            llm2 = self.registry.get_llm(service_id, self.llm_config)
+
+            # Verify same LLM instance is returned
+            self.assertEqual(llm1, llm2)
+            self.assertEqual(llm1, mock_llm)
+
+            # Verify _create_new_llm was only called once
+            mock_create.assert_called_once()
+
+    def test_get_llm_with_different_config_raises_error(self):
+        """Test that requesting same service ID with different config raises an error."""
+        service_id = 'test-service'
+        different_config = LLMConfig(model='different-model')
+
+        # Manually add an LLM to the registry to simulate existing service
+        mock_llm = MagicMock()
+        mock_llm.config = self.llm_config
+        self.registry.service_to_llm[service_id] = mock_llm
+
+        # Attempt to get LLM with different config should raise ValueError
+        with self.assertRaises(ValueError) as context:
+            self.registry.get_llm(service_id, different_config)
+
+        self.assertIn('Requesting same service ID', str(context.exception))
+        self.assertIn('with different config', str(context.exception))
+
+    def test_get_llm_without_config_raises_error(self):
+        """Test that requesting new LLM without config raises an error."""
+        service_id = 'test-service'
+
+        # Attempt to get LLM without providing config should raise ValueError
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
self.registry.get_llm(service_id, None)
|
||||||
|
|
||||||
|
self.assertIn(
|
||||||
|
'Requesting new LLM without specifying LLM config', str(context.exception)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_request_extraneous_completion(self):
|
||||||
|
"""Test that requesting an extraneous completion creates a new LLM if needed."""
|
||||||
|
service_id = 'extraneous-service'
|
||||||
|
messages = [{'role': 'user', 'content': 'Hello, world!'}]
|
||||||
|
|
||||||
|
# Mock the _create_new_llm method to avoid actual LLM initialization
|
||||||
|
with patch.object(self.registry, '_create_new_llm') as mock_create:
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = ' Hello from the LLM! '
|
||||||
|
mock_llm.completion.return_value = mock_response
|
||||||
|
mock_create.return_value = mock_llm
|
||||||
|
|
||||||
|
# Mock the side effect to add the LLM to the registry
|
||||||
|
def side_effect(*args, **kwargs):
|
||||||
|
self.registry.service_to_llm[service_id] = mock_llm
|
||||||
|
return mock_llm
|
||||||
|
|
||||||
|
mock_create.side_effect = side_effect
|
||||||
|
|
||||||
|
# Request a completion
|
||||||
|
response = self.registry.request_extraneous_completion(
|
||||||
|
service_id=service_id,
|
||||||
|
llm_config=self.llm_config,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the response (should be stripped)
|
||||||
|
self.assertEqual(response, 'Hello from the LLM!')
|
||||||
|
|
||||||
|
# Verify that _create_new_llm was called with correct parameters
|
||||||
|
mock_create.assert_called_once_with(
|
||||||
|
config=self.llm_config, service_id=service_id, with_listener=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify completion was called with correct messages
|
||||||
|
mock_llm.completion.assert_called_once_with(messages=messages)
|
||||||
|
|
||||||
|
def test_get_active_llm(self):
|
||||||
|
"""Test that get_active_llm returns the active agent LLM."""
|
||||||
|
active_llm = self.registry.get_active_llm()
|
||||||
|
self.assertEqual(active_llm, self.registry.active_agent_llm)
|
||||||
|
|
||||||
|
def test_subscribe_and_notify(self):
|
||||||
|
"""Test the subscription and notification system."""
|
||||||
|
events_received = []
|
||||||
|
|
||||||
|
def callback(event: RegistryEvent):
|
||||||
|
events_received.append(event)
|
||||||
|
|
||||||
|
# Subscribe to events
|
||||||
|
self.registry.subscribe(callback)
|
||||||
|
|
||||||
|
# Should receive notification for the active agent LLM
|
||||||
|
self.assertEqual(len(events_received), 1)
|
||||||
|
self.assertEqual(events_received[0].llm, self.registry.active_agent_llm)
|
||||||
|
self.assertEqual(
|
||||||
|
events_received[0].service_id, self.registry.active_agent_llm.service_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that the subscriber is set correctly
|
||||||
|
self.assertIsNotNone(self.registry.subscriber)
|
||||||
|
|
||||||
|
# Test notify method directly with a mock event
|
||||||
|
with patch.object(self.registry, 'subscriber') as mock_subscriber:
|
||||||
|
mock_event = MagicMock()
|
||||||
|
self.registry.notify(mock_event)
|
||||||
|
mock_subscriber.assert_called_once_with(mock_event)
|
||||||
|
|
||||||
|
def test_registry_has_unique_id(self):
|
||||||
|
"""Test that each registry instance has a unique ID."""
|
||||||
|
registry2 = LLMRegistry(config=self.config)
|
||||||
|
self.assertNotEqual(self.registry.registry_id, registry2.registry_id)
|
||||||
|
self.assertTrue(len(self.registry.registry_id) > 0)
|
||||||
|
self.assertTrue(len(registry2.registry_id) > 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
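
Taken together, the tests in tests/unit/test_llm_registry.py pin down the registry's caching contract: one LLM per service ID, reuse on repeat lookups, and hard failures on conflicting or missing configs. Below is a minimal sketch of that contract for orientation only; the names `FakeLLM` and `SketchRegistry` are illustrative, and the real implementation in `openhands/llm/llm_registry.py` differs in detail (for example, `_create_new_llm` also handles listeners).

```python
# A minimal sketch of the get_llm contract the tests above exercise.
# Not the real implementation; see openhands/llm/llm_registry.py.
from dataclasses import dataclass
from uuid import uuid4


@dataclass
class FakeLLM:
    config: object
    service_id: str


class SketchRegistry:
    def __init__(self, config):
        self.config = config
        self.registry_id = str(uuid4())  # unique per registry instance
        self.service_to_llm: dict[str, FakeLLM] = {}

    def _create_new_llm(self, config, service_id):
        llm = FakeLLM(config=config, service_id=service_id)
        self.service_to_llm[service_id] = llm
        return llm

    def get_llm(self, service_id, config=None):
        existing = self.service_to_llm.get(service_id)
        if existing is not None:
            # Same service ID with a conflicting config is a hard error.
            if config is not None and config != existing.config:
                raise ValueError(
                    f'Requesting same service ID {service_id} with different config'
                )
            return existing
        if config is None:
            raise ValueError('Requesting new LLM without specifying LLM config')
        return self._create_new_llm(config=config, service_id=service_id)
```

The design choice worth noting is the cache keyed by `service_id`: callers that agree on a service ID share one LLM instance and whatever state it accumulates.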
@@ -12,6 +12,8 @@ from openhands.core.config.mcp_config import (
     MCPSSEServerConfig,
     MCPStdioServerConfig,
 )
+from openhands.llm.llm_registry import LLMRegistry
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.server.session.conversation_init_data import ConversationInitData
 from openhands.server.session.session import Session
 from openhands.storage.memory import InMemoryFileStore
@@ -428,6 +430,8 @@ async def test_session_preserves_env_mcp_config(monkeypatch):
         file_store=InMemoryFileStore({}),
         config=config,
         sio=AsyncMock(),
+        llm_registry=LLMRegistry(config=OpenHandsConfig()),
+        convo_stats=ConversationStats(None, 'test-sid', None),
     )

     # Create empty settings
@@ -8,7 +8,8 @@ import pytest
 from mcp import McpError

 from openhands.controller.agent import Agent
-from openhands.controller.agent_controller import AgentController, AgentState
+from openhands.controller.agent_controller import AgentController
+from openhands.core.schema import AgentState
 from openhands.events.action.mcp import MCPAction
 from openhands.events.action.message import SystemMessageAction
 from openhands.events.event import EventSource
@@ -17,6 +18,8 @@ from openhands.events.stream import EventStream
 from openhands.mcp.client import MCPClient
 from openhands.mcp.tool import MCPClientTool
 from openhands.mcp.utils import call_tool_mcp
+from openhands.server.services.conversation_stats import ConversationStats
+from openhands.storage.memory import InMemoryFileStore


 class MockConfig:
@@ -34,6 +37,11 @@ class MockLLM:
         self.config = MockConfig()


+@pytest.fixture
+def convo_stats():
+    return ConversationStats(None, 'convo-id', None)
+
+
 class MockAgent(Agent):
     """Mock agent for testing."""

@@ -53,7 +61,7 @@ class MockAgent(Agent):


 @pytest.mark.asyncio
-async def test_mcp_tool_timeout_error_handling():
+async def test_mcp_tool_timeout_error_handling(convo_stats):
     """Test that verifies MCP tool timeout errors are properly handled and returned as observations."""
     # Create a mock MCPClient
     mock_client = mock.MagicMock(spec=MCPClient)
@@ -80,7 +88,7 @@ async def test_mcp_tool_timeout_error_handling():
     mock_client.tool_map = {'test_tool': mock_tool}

     # Create a mock file store
-    mock_file_store = mock.MagicMock()
+    mock_file_store = InMemoryFileStore({})

     # Create a mock event stream
     event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@@ -90,13 +98,12 @@ async def test_mcp_tool_timeout_error_handling():

     # Create a mock agent controller
     controller = AgentController(
-        sid='test-session',
-        file_store=mock_file_store,
-        user_id='test-user',
         agent=agent,
         event_stream=event_stream,
+        convo_stats=convo_stats,
         iteration_delta=10,
         budget_per_task_delta=None,
+        sid='test-session',
     )

     # Set up the agent state
@@ -143,7 +150,7 @@ async def test_mcp_tool_timeout_error_handling():


 @pytest.mark.asyncio
-async def test_mcp_tool_timeout_agent_continuation():
+async def test_mcp_tool_timeout_agent_continuation(convo_stats):
     """Test that verifies the agent can continue processing after an MCP tool timeout."""
     # Create a mock MCPClient
     mock_client = mock.MagicMock(spec=MCPClient)
@@ -170,7 +177,7 @@ async def test_mcp_tool_timeout_agent_continuation():
     mock_client.tool_map = {'test_tool': mock_tool}

     # Create a mock file store
-    mock_file_store = mock.MagicMock()
+    mock_file_store = InMemoryFileStore({})

     # Create a mock event stream
     event_stream = EventStream(sid='test-session', file_store=mock_file_store)
@@ -180,13 +187,12 @@ async def test_mcp_tool_timeout_agent_continuation():

     # Create a mock agent controller
     controller = AgentController(
-        sid='test-session',
-        file_store=mock_file_store,
-        user_id='test-user',
         agent=agent,
         event_stream=event_stream,
+        convo_stats=convo_stats,
         iteration_delta=10,
         budget_per_task_delta=None,
+        sid='test-session',
     )

     # Set up the agent state
@@ -21,11 +21,13 @@ from openhands.events.observation.agent import (
 from openhands.events.serialization.observation import observation_from_dict
 from openhands.events.stream import EventStream
 from openhands.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.llm.metrics import Metrics
 from openhands.memory.memory import Memory
 from openhands.runtime.impl.action_execution.action_execution_client import (
     ActionExecutionClient,
 )
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.server.session.agent_session import AgentSession
 from openhands.storage.memory import InMemoryFileStore
 from openhands.utils.prompt import (
@@ -42,6 +44,12 @@ def file_store():
     return InMemoryFileStore({})


+@pytest.fixture
+def mock_llm_registry(file_store):
+    """Create a mock LLMRegistry for testing."""
+    return MagicMock(spec=LLMRegistry)
+
+
 @pytest.fixture
 def event_stream(file_store):
     """Create a test event stream."""
@@ -90,24 +98,29 @@ def mock_agent():


 @pytest.mark.asyncio
-async def test_memory_on_event_exception_handling(memory, event_stream, mock_agent):
+async def test_memory_on_event_exception_handling(
+    memory, event_stream, mock_agent, mock_llm_registry
+):
     """Test that exceptions in Memory.on_event are properly handled via status callback."""
     # Create a mock runtime
     runtime = MagicMock(spec=ActionExecutionClient)
     runtime.event_stream = event_stream

     # Mock Memory method to raise an exception
-    with patch.object(
-        memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
+    with (
+        patch.object(
+            memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
+        ),
+        patch('openhands.core.main.create_agent', return_value=mock_agent),
     ):
         state = await run_controller(
             config=OpenHandsConfig(),
             initial_user_action=MessageAction(content='Test message'),
             runtime=runtime,
             sid='test',
-            agent=mock_agent,
             fake_user_response_fn=lambda _: 'repeat',
             memory=memory,
+            llm_registry=mock_llm_registry,
         )

     # Verify that the controller's last error was set
@@ -118,7 +131,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream, mock_age


 @pytest.mark.asyncio
 async def test_memory_on_workspace_context_recall_exception_handling(
-    memory, event_stream, mock_agent
+    memory, event_stream, mock_agent, mock_llm_registry
 ):
     """Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
     # Create a mock runtime
@@ -126,19 +139,22 @@ async def test_memory_on_workspace_context_recall_exception_handling(
     runtime.event_stream = event_stream

     # Mock Memory._on_workspace_context_recall to raise an exception
-    with patch.object(
-        memory,
-        '_find_microagent_knowledge',
-        side_effect=Exception('Test error from _find_microagent_knowledge'),
+    with (
+        patch.object(
+            memory,
+            '_find_microagent_knowledge',
+            side_effect=Exception('Test error from _find_microagent_knowledge'),
+        ),
+        patch('openhands.core.main.create_agent', return_value=mock_agent),
     ):
         state = await run_controller(
             config=OpenHandsConfig(),
             initial_user_action=MessageAction(content='Test message'),
             runtime=runtime,
             sid='test',
-            agent=mock_agent,
             fake_user_response_fn=lambda _: 'repeat',
             memory=memory,
+            llm_registry=mock_llm_registry,
         )

     # Verify that the controller's last error was set
@@ -593,12 +609,14 @@ REPOSITORY INSTRUCTIONS: This is the second test repository.


 @pytest.mark.asyncio
 async def test_conversation_instructions_plumbed_to_memory(
-    mock_agent, event_stream, file_store
+    mock_agent, event_stream, file_store, mock_llm_registry
 ):
     # Setup
     session = AgentSession(
         sid='test-session',
         file_store=file_store,
+        llm_registry=mock_llm_registry,
+        convo_stats=ConversationStats(file_store, 'test-session', None),
     )

     # Create a mock runtime and set it up
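
Both memory tests move from a single `patch.object` to the parenthesized multi-context-manager form (officially supported since Python 3.10), which keeps each patch on its own line. A small self-contained illustration of the same construct:

```python
# The flat, parenthesized `with` form used above (Python 3.10+); equivalent
# to nesting the two `with` statements, but each patch stays on its own line.
from unittest.mock import patch

with (
    patch('os.path.exists', return_value=True),
    patch('os.getcwd', return_value='/tmp'),
):
    import os

    assert os.path.exists('/nonexistent') and os.getcwd() == '/tmp'
```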
@@ -3,26 +3,30 @@ from litellm import ModelResponse

 from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
 from openhands.core.config import AgentConfig, LLMConfig
+from openhands.core.config.openhands_config import OpenHandsConfig
 from openhands.events.action import MessageAction
-from openhands.llm.llm import LLM
+from openhands.llm.llm_registry import LLMRegistry


 @pytest.fixture
-def mock_llm():
-    llm = LLM(
-        LLMConfig(
-            model='claude-3-5-sonnet-20241022',
-            api_key='fake',
-            caching_prompt=True,
-        )
-    )
-    return llm
+def llm_config():
+    return LLMConfig(
+        model='claude-3-5-sonnet-20241022',
+        api_key='fake',
+        caching_prompt=True,
+    )


 @pytest.fixture
-def codeact_agent(mock_llm):
+def llm_registry():
+    registry = LLMRegistry(config=OpenHandsConfig())
+    return registry
+
+
+@pytest.fixture
+def codeact_agent(llm_registry):
     config = AgentConfig()
-    agent = CodeActAgent(mock_llm, config)
+    agent = CodeActAgent(config, llm_registry)
     return agent
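
The fixture change mirrors the new constructor order seen throughout this commit: agents now receive `(config, llm_registry)` and resolve their LLM through the registry instead of being handed one. A hedged usage sketch, assuming a default `OpenHandsConfig()` is enough to build a registry in a test environment:

```python
# Sketch of constructing an agent under the new signature.
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

registry = LLMRegistry(config=OpenHandsConfig())
agent = CodeActAgent(AgentConfig(), registry)  # was: CodeActAgent(llm, config)
```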
@@ -12,14 +12,26 @@ from openhands.events.observation import NullObservation, Observation
 from openhands.events.stream import EventStream
 from openhands.integrations.provider import ProviderHandler, ProviderToken, ProviderType
 from openhands.integrations.service_types import AuthenticationError, Repository
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.base import Runtime
 from openhands.storage import get_file_store


-class TestRuntime(Runtime):
+class MockRuntime(Runtime):
     """A concrete implementation of Runtime for testing"""

     def __init__(self, *args, **kwargs):
+        # Ensure llm_registry is provided if not already in kwargs
+        if 'llm_registry' not in kwargs and len(args) < 3:
+            # Create a mock LLMRegistry if not provided
+            config = (
+                kwargs.get('config')
+                if 'config' in kwargs
+                else args[0]
+                if args
+                else OpenHandsConfig()
+            )
+            kwargs['llm_registry'] = LLMRegistry(config=config)
         super().__init__(*args, **kwargs)
         self.run_action_calls = []
         self._execute_shell_fn_git_handler = MagicMock(
@@ -89,9 +101,11 @@ def runtime(temp_dir):
     )
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(
+    llm_registry = LLMRegistry(config=config)
+    runtime = MockRuntime(
         config=config,
         event_stream=event_stream,
+        llm_registry=llm_registry,
         sid='test',
         user_id='test_user',
         git_provider_tokens=git_provider_tokens,
@@ -119,7 +133,7 @@ async def test_export_latest_git_provider_tokens_no_user_id(temp_dir):
     config = OpenHandsConfig()
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(config=config, event_stream=event_stream, sid='test')
+    runtime = MockRuntime(config=config, event_stream=event_stream, sid='test')

     # Create a command that would normally trigger token export
     cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
@@ -137,7 +151,7 @@ async def test_export_latest_git_provider_tokens_no_token_ref(temp_dir):
     config = OpenHandsConfig()
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config, event_stream=event_stream, sid='test', user_id='test_user'
     )

@@ -177,7 +191,7 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
     )
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config,
         event_stream=event_stream,
         sid='test',
@@ -225,7 +239,7 @@ async def test_clone_or_init_repo_no_repo_init_git_in_empty_workspace(temp_dir):
     config.init_git_in_empty_workspace = True
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config, event_stream=event_stream, sid='test', user_id=None
     )

@@ -249,7 +263,7 @@ async def test_clone_or_init_repo_no_repo_no_user_id_with_workspace_base(temp_di
     config.workspace_base = '/some/path'  # Set workspace_base
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config, event_stream=event_stream, sid='test', user_id=None
     )

@@ -267,7 +281,7 @@ async def test_clone_or_init_repo_auth_error(temp_dir):
     config = OpenHandsConfig()
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)
-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config, event_stream=event_stream, sid='test', user_id='test_user'
     )

@@ -298,7 +312,7 @@ async def test_clone_or_init_repo_github_with_token(temp_dir, monkeypatch):
         {ProviderType.GITHUB: ProviderToken(token=SecretStr(github_token))}
     )

-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config,
         event_stream=event_stream,
         sid='test',
@@ -336,7 +350,7 @@ async def test_clone_or_init_repo_github_no_token(temp_dir, monkeypatch):
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)

-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config, event_stream=event_stream, sid='test', user_id='test_user'
     )

@@ -371,7 +385,7 @@ async def test_clone_or_init_repo_gitlab_with_token(temp_dir, monkeypatch):
         {ProviderType.GITLAB: ProviderToken(token=SecretStr(gitlab_token))}
     )

-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config,
         event_stream=event_stream,
         sid='test',
@@ -410,7 +424,7 @@ async def test_clone_or_init_repo_with_branch(temp_dir, monkeypatch):
     file_store = get_file_store('local', temp_dir)
     event_stream = EventStream('abc', file_store)

-    runtime = TestRuntime(
+    runtime = MockRuntime(
         config=config, event_stream=event_stream, sid='test', user_id='test_user'
     )

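
The rename from `TestRuntime` to `MockRuntime` is more than cosmetic: pytest collects any class whose name starts with `Test`, and a collected class that defines `__init__` is skipped with a `PytestCollectionWarning`. A Mock-prefixed helper stays out of collection entirely. The same pattern in miniature (hypothetical class names for illustration):

```python
# Why the Test* -> Mock* rename matters: pytest attempts to collect
# Test-prefixed classes and warns (and skips them) when they define
# __init__; Mock-prefixed helpers are never considered for collection.
class MockHelper:  # fine: ignored by pytest collection
    def __init__(self, name: str) -> None:
        self.name = name


class TestHelper:  # would trigger PytestCollectionWarning under pytest
    def __init__(self, name: str) -> None:
        self.name = name


def test_mock_helper_is_usable():
    assert MockHelper('r').name == 'r'
```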
@@ -9,10 +9,12 @@ import pytest
 from openhands.core.config import OpenHandsConfig, SandboxConfig
 from openhands.events import EventStream
 from openhands.integrations.service_types import ProviderType, Repository
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.microagent.microagent import (
     RepoMicroagent,
 )
 from openhands.runtime.base import Runtime
+from openhands.storage import get_file_store


 class MockRuntime(Runtime):
@@ -24,12 +26,21 @@ class MockRuntime(Runtime):
         config.workspace_mount_path_in_sandbox = str(workspace_root)
         config.sandbox = SandboxConfig()

-        # Create a mock event stream
+        # Create a mock event stream and file store
+        file_store = get_file_store('local', str(workspace_root))
         event_stream = MagicMock(spec=EventStream)
+        event_stream.file_store = file_store
+
+        # Create a mock LLM registry
+        llm_registry = LLMRegistry(config)

         # Initialize the parent class properly
         super().__init__(
-            config=config, event_stream=event_stream, sid='test', git_provider_tokens={}
+            config=config,
+            event_stream=event_stream,
+            llm_registry=llm_registry,
+            sid='test',
+            git_provider_tokens={},
         )

         self._workspace_root = workspace_root
@@ -595,7 +595,7 @@ async def test_check_usertask(
     analyzer = InvariantAnalyzer(event_stream)
     mock_response = {'choices': [{'message': {'content': is_appropriate}}]}
     mock_litellm_completion.return_value = mock_response
-    analyzer.guardrail_llm = LLM(config=default_config)
+    analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
     analyzer.check_browsing_alignment = True
     data = [
         (MessageAction(usertask), EventSource.USER),
@@ -657,7 +657,7 @@ async def test_check_fillaction(
     analyzer = InvariantAnalyzer(event_stream)
     mock_response = {'choices': [{'message': {'content': is_harmful}}]}
     mock_litellm_completion.return_value = mock_response
-    analyzer.guardrail_llm = LLM(config=default_config)
+    analyzer.guardrail_llm = LLM(config=default_config, service_id='test')
     analyzer.check_browsing_alignment = True
     data = [
         (BrowseInteractiveAction(browser_actions=fillaction), EventSource.AGENT),
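
Note that even direct `LLM` construction, outside the registry, now carries a `service_id`, presumably so usage can be attributed per service. A sketch of the updated call, with an illustrative service name:

```python
# Sketch only: direct LLM construction with the new service_id argument,
# matching the two test changes above; 'guardrail' is a made-up service name.
from openhands.core.config.llm_config import LLMConfig
from openhands.llm import LLM

guardrail_llm = LLM(config=LLMConfig(model='test-model'), service_id='guardrail')
```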
@@ -7,7 +7,9 @@ from litellm.exceptions import (

 from openhands.core.config.llm_config import LLMConfig
 from openhands.core.config.openhands_config import OpenHandsConfig
+from openhands.llm.llm_registry import LLMRegistry
 from openhands.runtime.runtime_status import RuntimeStatus
+from openhands.server.services.conversation_stats import ConversationStats
 from openhands.server.session.session import Session
 from openhands.storage.memory import InMemoryFileStore

@@ -33,10 +35,28 @@ def default_llm_config():
     )


+@pytest.fixture
+def llm_registry():
+    config = OpenHandsConfig()
+    return LLMRegistry(config=config)
+
+
+@pytest.fixture
+def conversation_stats():
+    file_store = InMemoryFileStore({})
+    return ConversationStats(
+        file_store=file_store, conversation_id='test-conversation', user_id='test-user'
+    )
+
+
 @pytest.mark.asyncio
 @patch('openhands.llm.llm.litellm_completion')
 async def test_notify_on_llm_retry(
-    mock_litellm_completion, mock_sio, default_llm_config
+    mock_litellm_completion,
+    mock_sio,
+    default_llm_config,
+    llm_registry,
+    conversation_stats,
 ):
     config = OpenHandsConfig()
     config.set_llm_config(default_llm_config)
@@ -44,6 +64,8 @@ async def test_notify_on_llm_retry(
         sid='..sid..',
         file_store=InMemoryFileStore({}),
         config=config,
+        llm_registry=llm_registry,
+        convo_stats=conversation_stats,
         sio=mock_sio,
         user_id='..uid..',
     )
@@ -56,12 +78,20 @@ async def test_notify_on_llm_retry(
         ),
         {'choices': [{'message': {'content': 'Retry successful'}}]},
     ]
-    llm = session._create_llm('..cls..')

-    llm.completion(
-        messages=[{'role': 'user', 'content': 'Hello!'}],
-        stream=False,
-    )
+    # Set the retry listener on the registry
+    llm_registry.retry_listner = session._notify_on_llm_retry
+
+    # Create an LLM through the registry
+    llm = llm_registry.get_llm(
+        service_id='test_service',
+        config=default_llm_config,
+    )
+
+    llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )

     assert mock_litellm_completion.call_count == 2
     session.queue_status_message.assert_called_once_with(
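
The rewritten test also documents how retry notification now flows: the session no longer builds the LLM itself via `_create_llm`; it hangs its callback on the registry (note the `retry_listner` spelling, which matches the attribute name in the source), and LLMs created through the registry report retries to it. A rough sketch of that wiring; the callback signature `(attempt, max_attempts)` is an assumption, not taken from this diff:

```python
# Rough sketch of the retry wiring exercised above; no completion call is
# made here, since that would hit a real provider.
from openhands.core.config.llm_config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.llm.llm_registry import LLMRegistry

registry = LLMRegistry(config=OpenHandsConfig())


def on_retry(attempt, max_attempts):  # stands in for session._notify_on_llm_retry
    print(f'LLM retry {attempt}/{max_attempts}')


registry.retry_listner = on_retry  # sic: attribute spelling as in the source
llm = registry.get_llm(service_id='docs-example', config=LLMConfig(model='test-model'))
# Retryable completion errors now surface through on_retry before each
# re-attempt, instead of requiring session._create_llm to wire the callback.
```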