mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
fix/git-di
...
openhands-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
568477c928 |
@@ -110,4 +110,4 @@ The agent is implemented in two main files:
|
||||
2. `function_calling.py`: Tool definitions and function calling interface with:
|
||||
- Tool parameter specifications
|
||||
- Tool descriptions and examples
|
||||
- Function calling response parsing
|
||||
- Function calling response parsing
|
||||
|
||||
@@ -191,11 +191,19 @@ class AgentController:
|
||||
self,
|
||||
e: Exception,
|
||||
):
|
||||
# Reset state before changing to error state to ensure proper cleanup
|
||||
if isinstance(e, RuntimeError) and str(e) == 'Agent got stuck in a loop':
|
||||
# Reset the agent's history to allow for new messages
|
||||
self.state.history = []
|
||||
self._stuck_detector = StuckDetector(self.state)
|
||||
|
||||
await self.set_agent_state_to(AgentState.ERROR)
|
||||
if self.status_callback is not None:
|
||||
err_id = ''
|
||||
if isinstance(e, litellm.AuthenticationError):
|
||||
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
|
||||
elif isinstance(e, RuntimeError) and str(e) == 'Agent got stuck in a loop':
|
||||
err_id = 'STATUS$ERROR_AGENT_STUCK'
|
||||
self.status_callback('error', err_id, str(e))
|
||||
|
||||
async def start_step_loop(self):
|
||||
@@ -313,14 +321,20 @@ class AgentController:
|
||||
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
|
||||
)
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
# Allow transitioning from ERROR state to RUNNING state when receiving a new message
|
||||
if self.get_agent_state() == AgentState.ERROR:
|
||||
self.reset_task()
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif action.source == EventSource.AGENT and action.wait_for_response:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
|
||||
def reset_task(self) -> None:
|
||||
"""Resets the agent's task."""
|
||||
"""Resets the agent's task and state."""
|
||||
self.almost_stuck = 0
|
||||
self.agent.reset()
|
||||
# Reset the stuck detector state
|
||||
if hasattr(self, '_stuck_detector'):
|
||||
self._stuck_detector = StuckDetector(self.state)
|
||||
|
||||
async def set_agent_state_to(self, new_state: AgentState) -> None:
|
||||
"""Updates the agent's state and handles side effects. Can emit events to the event stream.
|
||||
|
||||
@@ -142,7 +142,9 @@ class SessionManager:
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
await conversation.disconnect()
|
||||
|
||||
async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
|
||||
async def init_or_join_session(
|
||||
self, sid: str, connection_id: str, session_init_data: SessionInitData
|
||||
):
|
||||
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
|
||||
self.local_connection_id_to_session_id[connection_id] = sid
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import socketio
|
||||
|
||||
@@ -9,7 +9,6 @@ from openhands.core.config import AppConfig
|
||||
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.core.schema.config import ConfigType
|
||||
from openhands.events.action import MessageAction, NullAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation import (
|
||||
@@ -68,15 +67,28 @@ class Session:
|
||||
)
|
||||
# Extract the agent-relevant arguments from the request
|
||||
agent_cls = session_init_data.agent or self.config.default_agent
|
||||
self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
|
||||
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
|
||||
self.config.security.confirmation_mode = (
|
||||
self.config.security.confirmation_mode
|
||||
if session_init_data.confirmation_mode is None
|
||||
else session_init_data.confirmation_mode
|
||||
)
|
||||
self.config.security.security_analyzer = (
|
||||
session_init_data.security_analyzer
|
||||
or self.config.security.security_analyzer
|
||||
)
|
||||
max_iterations = session_init_data.max_iterations or self.config.max_iterations
|
||||
# override default LLM config
|
||||
|
||||
default_llm_config = self.config.get_llm_config()
|
||||
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
|
||||
default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
|
||||
default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
|
||||
default_llm_config.model = (
|
||||
session_init_data.llm_model or default_llm_config.model
|
||||
)
|
||||
default_llm_config.api_key = (
|
||||
session_init_data.llm_api_key or default_llm_config.api_key
|
||||
)
|
||||
default_llm_config.base_url = (
|
||||
session_init_data.llm_base_url or default_llm_config.base_url
|
||||
)
|
||||
|
||||
# TODO: override other LLM config & agent config groups (#2075)
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -8,6 +6,7 @@ class SessionInitData:
|
||||
"""
|
||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||
"""
|
||||
|
||||
language: str | None = None
|
||||
agent: str | None = None
|
||||
max_iterations: int | None = None
|
||||
|
||||
@@ -98,6 +98,7 @@ reportlab = "*"
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -128,6 +129,7 @@ ignore = ["D1"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
66
tests/unit/test_stuck_recovery.py
Normal file
66
tests/unit/test_stuck_recovery.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stuck_recovery():
|
||||
# Mock dependencies
|
||||
event_stream = EventStream("test_session", MagicMock())
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.name = "test_agent"
|
||||
agent.sandbox_plugins = []
|
||||
|
||||
# Mock LLM
|
||||
llm_mock = MagicMock()
|
||||
llm_mock.metrics = {}
|
||||
llm_mock.config.model = "test_model"
|
||||
llm_mock.config.base_url = "test_url"
|
||||
llm_mock.config.draft_editor = None
|
||||
agent.llm = llm_mock
|
||||
|
||||
# Create controller
|
||||
controller = AgentController(
|
||||
sid="test_session",
|
||||
event_stream=event_stream,
|
||||
agent=agent,
|
||||
max_iterations=10,
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Add repeated messages to simulate stuck state
|
||||
message = MessageAction("test message", wait_for_response=False)
|
||||
observation = ErrorObservation("test error", "test_cause")
|
||||
|
||||
# Add 4 pairs of the same message and observation to simulate a loop
|
||||
for _ in range(4):
|
||||
controller.state.history.append(message)
|
||||
controller.state.history.append(observation)
|
||||
|
||||
# Verify stuck detection
|
||||
assert controller._stuck_detector.is_stuck()
|
||||
|
||||
# Simulate error handling
|
||||
await controller._react_to_exception(RuntimeError("Agent got stuck in a loop"))
|
||||
|
||||
# Verify state is reset
|
||||
assert len(controller.state.history) == 0
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
|
||||
# Create a message and set its source directly
|
||||
new_message = MessageAction("new message", wait_for_response=False)
|
||||
new_message._source = EventSource.USER # Access private attribute for testing
|
||||
|
||||
# Verify controller can process new messages
|
||||
await controller._handle_message_action(new_message)
|
||||
assert controller.state.agent_state == AgentState.RUNNING
|
||||
Reference in New Issue
Block a user