Compare commits

...

1 Commit

Author SHA1 Message Date
openhands
568477c928 Fix issue #5480: [Bug]: Cannot recover from "Agent stuck in loop" 2024-12-09 19:55:08 +00:00
7 changed files with 107 additions and 12 deletions

View File

@@ -110,4 +110,4 @@ The agent is implemented in two main files:
2. `function_calling.py`: Tool definitions and function calling interface with:
- Tool parameter specifications
- Tool descriptions and examples
- Function calling response parsing
- Function calling response parsing

View File

@@ -191,11 +191,19 @@ class AgentController:
self,
e: Exception,
):
# Reset state before changing to error state to ensure proper cleanup
if isinstance(e, RuntimeError) and str(e) == 'Agent got stuck in a loop':
# Reset the agent's history to allow for new messages
self.state.history = []
self._stuck_detector = StuckDetector(self.state)
await self.set_agent_state_to(AgentState.ERROR)
if self.status_callback is not None:
err_id = ''
if isinstance(e, litellm.AuthenticationError):
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
elif isinstance(e, RuntimeError) and str(e) == 'Agent got stuck in a loop':
err_id = 'STATUS$ERROR_AGENT_STUCK'
self.status_callback('error', err_id, str(e))
async def start_step_loop(self):
@@ -313,14 +321,20 @@ class AgentController:
extra={'msg_type': 'ACTION', 'event_source': EventSource.USER},
)
if self.get_agent_state() != AgentState.RUNNING:
# Allow transitioning from ERROR state to RUNNING state when receiving a new message
if self.get_agent_state() == AgentState.ERROR:
self.reset_task()
await self.set_agent_state_to(AgentState.RUNNING)
elif action.source == EventSource.AGENT and action.wait_for_response:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
def reset_task(self) -> None:
    """Resets the agent's task and state.

    Zeroes the ``almost_stuck`` counter, calls ``self.agent.reset()`` and,
    when a ``_stuck_detector`` attribute is already present on the instance,
    replaces it with a new ``StuckDetector`` built from ``self.state``.
    """
    self.almost_stuck = 0
    self.agent.reset()
    # Only rebuild the detector when one exists: the hasattr guard keeps this
    # method callable on an instance that has not set ``_stuck_detector``.
    if hasattr(self, '_stuck_detector'):
        self._stuck_detector = StuckDetector(self.state)
async def set_agent_state_to(self, new_state: AgentState) -> None:
"""Updates the agent's state and handles side effects. Can emit events to the event stream.

View File

@@ -142,7 +142,9 @@ class SessionManager:
async def detach_from_conversation(self, conversation: Conversation):
    """Detach from ``conversation`` by awaiting its ``disconnect()`` coroutine."""
    await conversation.disconnect()
async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
async def init_or_join_session(
self, sid: str, connection_id: str, session_init_data: SessionInitData
):
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid

View File

@@ -1,6 +1,6 @@
import asyncio
from copy import deepcopy
import time
from copy import deepcopy
import socketio
@@ -9,7 +9,6 @@ from openhands.core.config import AppConfig
from openhands.core.const.guide_url import TROUBLESHOOTING_URL
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.core.schema.config import ConfigType
from openhands.events.action import MessageAction, NullAction
from openhands.events.event import Event, EventSource
from openhands.events.observation import (
@@ -68,15 +67,28 @@ class Session:
)
# Extract the agent-relevant arguments from the request
agent_cls = session_init_data.agent or self.config.default_agent
self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
self.config.security.confirmation_mode = (
self.config.security.confirmation_mode
if session_init_data.confirmation_mode is None
else session_init_data.confirmation_mode
)
self.config.security.security_analyzer = (
session_init_data.security_analyzer
or self.config.security.security_analyzer
)
max_iterations = session_init_data.max_iterations or self.config.max_iterations
# override default LLM config
default_llm_config = self.config.get_llm_config()
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
default_llm_config.model = (
session_init_data.llm_model or default_llm_config.model
)
default_llm_config.api_key = (
session_init_data.llm_api_key or default_llm_config.api_key
)
default_llm_config.base_url = (
session_init_data.llm_base_url or default_llm_config.base_url
)
# TODO: override other LLM config & agent config groups (#2075)

View File

@@ -1,5 +1,3 @@
from dataclasses import dataclass
@@ -8,6 +6,7 @@ class SessionInitData:
"""
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
language: str | None = None
agent: str | None = None
max_iterations: int | None = None

View File

@@ -98,6 +98,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -128,6 +129,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"

View File

@@ -0,0 +1,66 @@
import asyncio
import pytest
from unittest.mock import MagicMock
from openhands.controller.agent_controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.schema.agent import AgentState
from openhands.events.stream import EventStream
from openhands.events.action import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation.error import ErrorObservation
@pytest.mark.asyncio
async def test_stuck_recovery():
    """Check that a controller flagged as stuck in a loop accepts new user input.

    Flow: seed the history with a repeated action/observation pair until the
    stuck detector reports a loop, pass the "stuck" ``RuntimeError`` to the
    controller's exception handler, then confirm that a fresh user message
    moves the controller back to the RUNNING state.
    """
    # Stand-in LLM: only the attributes assigned here are configured explicitly.
    fake_llm = MagicMock()
    fake_llm.metrics = {}
    fake_llm.config.model = "test_model"
    fake_llm.config.base_url = "test_url"
    fake_llm.config.draft_editor = None

    # Stand-in agent constrained to the Agent interface.
    fake_agent = MagicMock(spec=Agent)
    fake_agent.name = "test_agent"
    fake_agent.sandbox_plugins = []
    fake_agent.llm = fake_llm

    controller = AgentController(
        sid="test_session",
        event_stream=EventStream("test_session", MagicMock()),
        agent=fake_agent,
        max_iterations=10,
        confirmation_mode=False,
        headless_mode=True,
    )

    # Four identical action/observation pairs simulate a loop.
    repeated_action = MessageAction("test message", wait_for_response=False)
    repeated_error = ErrorObservation("test error", "test_cause")
    history = controller.state.history
    for _ in range(4):
        history.extend((repeated_action, repeated_error))

    # The detector must recognise the repeated pattern.
    assert controller._stuck_detector.is_stuck()

    # Route the "stuck" error through the controller's exception handler.
    await controller._react_to_exception(RuntimeError("Agent got stuck in a loop"))

    # History is cleared and the controller sits in the ERROR state.
    assert len(controller.state.history) == 0
    assert controller.state.agent_state == AgentState.ERROR

    # A new user message should bring the controller back to RUNNING.
    followup = MessageAction("new message", wait_for_response=False)
    followup._source = EventSource.USER  # Access private attribute for testing
    await controller._handle_message_action(followup)
    assert controller.state.agent_state == AgentState.RUNNING