mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
Add ability to restore the cli session (optional) (#2699)
* add ability to restore the main session * add quick log * rename to cli session
This commit is contained in:
@@ -156,7 +156,7 @@ def process_instance(
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': {
|
||||
'success': test_result,
|
||||
'final_message': final_message,
|
||||
|
||||
@@ -236,7 +236,7 @@ def process_instance(
|
||||
'metadata': metadata,
|
||||
'history': histories,
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': {
|
||||
'agent_answer': agent_answer,
|
||||
'final_answer': final_ans,
|
||||
|
||||
@@ -243,7 +243,7 @@ def process_instance(
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ def process_instance(
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
return output
|
||||
|
||||
@@ -197,7 +197,7 @@ def process_instance(instance, agent_class, metadata, reset_logger: bool = True)
|
||||
for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
except Exception:
|
||||
|
||||
@@ -153,7 +153,7 @@ def process_instance(
|
||||
for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
}
|
||||
except Exception:
|
||||
logger.error('Process instance failed')
|
||||
|
||||
@@ -288,7 +288,7 @@ def process_instance(
|
||||
for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ def process_instance(
|
||||
for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
except Exception:
|
||||
|
||||
@@ -263,7 +263,7 @@ def process_instance(
|
||||
'metrics': metrics,
|
||||
'final_message': final_message,
|
||||
'messages': messages,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
except Exception:
|
||||
|
||||
@@ -100,7 +100,7 @@ def process_instance(
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': reward,
|
||||
}
|
||||
|
||||
|
||||
@@ -185,7 +185,7 @@ def process_instance(
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': task_state.success if task_state else False,
|
||||
}
|
||||
|
||||
|
||||
@@ -343,7 +343,7 @@ IMPORTANT TIPS:
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': test_result,
|
||||
}
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ def process_instance(task, agent_class, metadata, reset_logger: bool = True):
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def process_instance(
|
||||
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
|
||||
],
|
||||
'metrics': metrics,
|
||||
'error': state.error if state and state.error else None,
|
||||
'error': state.last_error if state and state.last_error else None,
|
||||
'test_result': reward,
|
||||
}
|
||||
|
||||
|
||||
@@ -74,14 +74,19 @@ class AgentController:
|
||||
self._step_lock = asyncio.Lock()
|
||||
self.id = sid
|
||||
self.agent = agent
|
||||
if initial_state is None:
|
||||
self.state = State(inputs={}, max_iterations=max_iterations)
|
||||
else:
|
||||
self.state = initial_state
|
||||
|
||||
# subscribe to the event stream
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
|
||||
)
|
||||
|
||||
# state from the previous session, state from a parent agent, or a fresh state
|
||||
self._set_initial_state(
|
||||
state=initial_state,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
self.max_budget_per_task = max_budget_per_task
|
||||
if not is_delegate:
|
||||
self.agent_task = asyncio.create_task(self._start_step_loop())
|
||||
@@ -108,9 +113,9 @@ class AgentController:
|
||||
- the string message should be user-friendly, it will be shown in the UI
|
||||
- an ErrorObservation can be sent to the LLM by the agent, with the exception message, so it can self-correct next time
|
||||
"""
|
||||
self.state.error = message
|
||||
self.state.last_error = message
|
||||
if exception:
|
||||
self.state.error += f': {exception}'
|
||||
self.state.last_error += f': {exception}'
|
||||
await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
|
||||
|
||||
async def add_history(self, action: Action, observation: Observation):
|
||||
@@ -182,7 +187,7 @@ class AgentController:
|
||||
self.agent.reset()
|
||||
|
||||
async def set_agent_state_to(self, new_state: AgentState):
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
|
||||
)
|
||||
|
||||
@@ -235,7 +240,7 @@ class AgentController:
|
||||
return
|
||||
|
||||
if self._pending_action:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
@@ -336,8 +341,15 @@ class AgentController:
|
||||
def get_state(self):
|
||||
return self.state
|
||||
|
||||
def set_state(self, state: State):
|
||||
self.state = state
|
||||
def _set_initial_state(
|
||||
self, state: State | None, max_iterations: int = MAX_ITERATIONS
|
||||
):
|
||||
# state from the previous session, state from a parent agent, or a new state
|
||||
# note that this is called twice when restoring a previous session, first with state=None
|
||||
if state is None:
|
||||
self.state = State(inputs={}, max_iterations=max_iterations)
|
||||
else:
|
||||
self.state = state
|
||||
|
||||
def _is_stuck(self):
|
||||
# check if delegate stuck
|
||||
|
||||
@@ -34,7 +34,7 @@ class State:
|
||||
updated_info: list[tuple[Action, Observation]] = field(default_factory=list)
|
||||
inputs: dict = field(default_factory=dict)
|
||||
outputs: dict = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
last_error: str | None = None
|
||||
agent_state: AgentState = AgentState.LOADING
|
||||
resume_state: AgentState | None = None
|
||||
metrics: Metrics = Metrics()
|
||||
|
||||
@@ -154,6 +154,7 @@ class AppConfig(metaclass=Singleton):
|
||||
sandbox_timeout: The timeout for the sandbox.
|
||||
debug: Whether to enable debugging.
|
||||
enable_auto_lint: Whether to enable auto linting. This is False by default, for regular runs of the app. For evaluation, please set this to True.
|
||||
enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
|
||||
file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
|
||||
file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
|
||||
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
|
||||
@@ -195,6 +196,7 @@ class AppConfig(metaclass=Singleton):
|
||||
enable_auto_lint: bool = (
|
||||
False # once enabled, OpenDevin would lint files after editing
|
||||
)
|
||||
enable_cli_session: bool = False
|
||||
file_uploads_max_file_size_mb: int = 0
|
||||
file_uploads_restrict_file_types: bool = False
|
||||
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import Callable, Optional, Type
|
||||
from typing import Callable, Type
|
||||
|
||||
import agenthub # noqa F401 (we import this to get the agents registered)
|
||||
from opendevin.controller import AgentController
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.state.state import State
|
||||
from opendevin.core.config import args, get_llm_config_arg
|
||||
from opendevin.core.config import args, config, get_llm_config_arg
|
||||
from opendevin.core.logger import opendevin_logger as logger
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.events import EventSource, EventStream, EventStreamSubscriber
|
||||
@@ -33,11 +33,11 @@ def read_task_from_stdin() -> str:
|
||||
async def main(
|
||||
task_str: str = '',
|
||||
exit_on_message: bool = False,
|
||||
fake_user_response_fn: Optional[Callable[[Optional[State]], str]] = None,
|
||||
sandbox: Optional[Sandbox] = None,
|
||||
runtime_tools_config: Optional[dict] = None,
|
||||
fake_user_response_fn: Callable[[State | None], str] | None = None,
|
||||
sandbox: Sandbox | None = None,
|
||||
runtime_tools_config: dict | None = None,
|
||||
sid: str | None = None,
|
||||
) -> Optional[State]:
|
||||
) -> State | None:
|
||||
"""Main coroutine to run the agent controller with task input flexibility.
|
||||
It's only used when you launch opendevin backend directly via cmdline.
|
||||
|
||||
@@ -82,16 +82,33 @@ async def main(
|
||||
)
|
||||
llm = LLM(args.model_name)
|
||||
|
||||
# set up the agent
|
||||
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
|
||||
agent = AgentCls(llm=llm)
|
||||
|
||||
event_stream = EventStream('main' + ('_' + sid if sid else ''))
|
||||
# set up the event stream
|
||||
cli_session = 'main' + ('_' + sid if sid else '')
|
||||
event_stream = EventStream(cli_session)
|
||||
|
||||
# restore cli session if enabled
|
||||
initial_state = None
|
||||
if config.enable_cli_session:
|
||||
try:
|
||||
logger.info('Restoring agent state from cli session')
|
||||
initial_state = State.restore_from_session(cli_session)
|
||||
except Exception as e:
|
||||
print('Error restoring state', e)
|
||||
|
||||
# init controller with this initial state
|
||||
controller = AgentController(
|
||||
agent=agent,
|
||||
max_iterations=args.max_iterations,
|
||||
max_budget_per_task=args.max_budget_per_task,
|
||||
event_stream=event_stream,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
# runtime and tools
|
||||
runtime = ServerRuntime(event_stream=event_stream, sandbox=sandbox)
|
||||
runtime.init_sandbox_plugins(controller.agent.sandbox_plugins)
|
||||
runtime.init_runtime_tools(
|
||||
@@ -110,7 +127,18 @@ async def main(
|
||||
task = f.read()
|
||||
logger.info(f'Dynamic Eval task: {task}')
|
||||
|
||||
await event_stream.add_event(MessageAction(content=task), EventSource.USER)
|
||||
# start event is a MessageAction with the task, either resumed or new
|
||||
if config.enable_cli_session and initial_state is not None:
|
||||
# we're resuming the previous session
|
||||
await event_stream.add_event(
|
||||
MessageAction(
|
||||
content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
|
||||
),
|
||||
EventSource.USER,
|
||||
)
|
||||
elif initial_state is None:
|
||||
# init with the provided task
|
||||
await event_stream.add_event(MessageAction(content=task), EventSource.USER)
|
||||
|
||||
async def on_event(event: Event):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
@@ -134,6 +162,12 @@ async def main(
|
||||
]:
|
||||
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
|
||||
|
||||
# save session when we're about to close
|
||||
if config.enable_cli_session:
|
||||
end_state = controller.get_state()
|
||||
end_state.save_to_session(cli_session)
|
||||
|
||||
# close when done
|
||||
await controller.close()
|
||||
runtime.close()
|
||||
return controller.get_state()
|
||||
|
||||
@@ -111,7 +111,7 @@ class AgentSession:
|
||||
)
|
||||
try:
|
||||
agent_state = State.restore_from_session(self.sid)
|
||||
self.controller.set_state(agent_state)
|
||||
self.controller._set_initial_state(agent_state)
|
||||
logger.info(f'Restored agent state from session, sid: {self.sid}')
|
||||
except Exception as e:
|
||||
print('Error restoring state', e)
|
||||
|
||||
@@ -40,7 +40,7 @@ def test_write_simple_script():
|
||||
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
|
||||
final_state: State = asyncio.run(main(task, exit_on_message=True))
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.error is None
|
||||
assert final_state.last_error is None
|
||||
|
||||
# Verify the script file exists
|
||||
script_path = os.path.join(workspace_base, 'hello.sh')
|
||||
@@ -86,7 +86,7 @@ def test_edits():
|
||||
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
|
||||
final_state: State = asyncio.run(main(task, exit_on_message=True))
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.error is None
|
||||
assert final_state.last_error is None
|
||||
|
||||
# Verify bad.txt has been fixed
|
||||
text = """This is a stupid typo.
|
||||
@@ -112,7 +112,7 @@ def test_ipython():
|
||||
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
|
||||
final_state: State = asyncio.run(main(task, exit_on_message=True))
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.error is None
|
||||
assert final_state.last_error is None
|
||||
|
||||
# Verify the file exists
|
||||
file_path = os.path.join(workspace_base, 'test.txt')
|
||||
@@ -140,7 +140,7 @@ def test_simple_task_rejection():
|
||||
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
|
||||
final_state: State = asyncio.run(main(task))
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.error is None
|
||||
assert final_state.last_error is None
|
||||
assert isinstance(final_state.history[-1][0], AgentRejectAction)
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ def test_ipython_module():
|
||||
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
|
||||
final_state: State = asyncio.run(main(task, exit_on_message=True))
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.error is None
|
||||
assert final_state.last_error is None
|
||||
|
||||
# Verify the file exists
|
||||
file_path = os.path.join(workspace_base, 'test.txt')
|
||||
@@ -186,6 +186,6 @@ def test_browse_internet(http_server):
|
||||
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
|
||||
final_state: State = asyncio.run(main(task, exit_on_message=True))
|
||||
assert final_state.agent_state == AgentState.STOPPED
|
||||
assert final_state.error is None
|
||||
assert final_state.last_error is None
|
||||
assert isinstance(final_state.history[-1][0], AgentFinishAction)
|
||||
assert 'OpenDevin is all you need!' in str(final_state.history)
|
||||
|
||||
Reference in New Issue
Block a user