Add ability to restore the cli session (optional) (#2699)

* add ability to restore the main session

* add quick log

* rename to cli session
This commit is contained in:
Engel Nyst
2024-06-30 08:56:55 +02:00
committed by GitHub
parent 874b4c9075
commit 2d9bb56763
20 changed files with 88 additions and 40 deletions

View File

@@ -156,7 +156,7 @@ def process_instance(
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': {
'success': test_result,
'final_message': final_message,

View File

@@ -236,7 +236,7 @@ def process_instance(
'metadata': metadata,
'history': histories,
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': {
'agent_answer': agent_answer,
'final_answer': final_ans,

View File

@@ -243,7 +243,7 @@ def process_instance(
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}

View File

@@ -245,7 +245,7 @@ def process_instance(
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}
return output

View File

@@ -197,7 +197,7 @@ def process_instance(instance, agent_class, metadata, reset_logger: bool = True)
for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}
except Exception:

View File

@@ -153,7 +153,7 @@ def process_instance(
for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
}
except Exception:
logger.error('Process instance failed')

View File

@@ -288,7 +288,7 @@ def process_instance(
for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}

View File

@@ -241,7 +241,7 @@ def process_instance(
for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}
except Exception:

View File

@@ -263,7 +263,7 @@ def process_instance(
'metrics': metrics,
'final_message': final_message,
'messages': messages,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}
except Exception:

View File

@@ -100,7 +100,7 @@ def process_instance(
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': reward,
}

View File

@@ -185,7 +185,7 @@ def process_instance(
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': task_state.success if task_state else False,
}

View File

@@ -343,7 +343,7 @@ IMPORTANT TIPS:
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': test_result,
}

View File

@@ -139,7 +139,7 @@ def process_instance(task, agent_class, metadata, reset_logger: bool = True):
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
}
return output

View File

@@ -100,7 +100,7 @@ def process_instance(
(event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
],
'metrics': metrics,
'error': state.error if state and state.error else None,
'error': state.last_error if state and state.last_error else None,
'test_result': reward,
}

View File

@@ -74,14 +74,19 @@ class AgentController:
self._step_lock = asyncio.Lock()
self.id = sid
self.agent = agent
if initial_state is None:
self.state = State(inputs={}, max_iterations=max_iterations)
else:
self.state = initial_state
# subscribe to the event stream
self.event_stream = event_stream
self.event_stream.subscribe(
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
)
# state from the previous session, state from a parent agent, or a fresh state
self._set_initial_state(
state=initial_state,
max_iterations=max_iterations,
)
self.max_budget_per_task = max_budget_per_task
if not is_delegate:
self.agent_task = asyncio.create_task(self._start_step_loop())
@@ -108,9 +113,9 @@ class AgentController:
- the string message should be user-friendly, it will be shown in the UI
- an ErrorObservation can be sent to the LLM by the agent, with the exception message, so it can self-correct next time
"""
self.state.error = message
self.state.last_error = message
if exception:
self.state.error += f': {exception}'
self.state.last_error += f': {exception}'
await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
async def add_history(self, action: Action, observation: Observation):
@@ -182,7 +187,7 @@ class AgentController:
self.agent.reset()
async def set_agent_state_to(self, new_state: AgentState):
logger.info(
logger.debug(
f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
)
@@ -235,7 +240,7 @@ class AgentController:
return
if self._pending_action:
logger.info(
logger.debug(
f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
)
await asyncio.sleep(1)
@@ -336,8 +341,15 @@ class AgentController:
def get_state(self):
return self.state
def set_state(self, state: State):
self.state = state
def _set_initial_state(
self, state: State | None, max_iterations: int = MAX_ITERATIONS
):
# state from the previous session, state from a parent agent, or a new state
# note that this is called twice when restoring a previous session, first with state=None
if state is None:
self.state = State(inputs={}, max_iterations=max_iterations)
else:
self.state = state
def _is_stuck(self):
# check if delegate stuck

View File

@@ -34,7 +34,7 @@ class State:
updated_info: list[tuple[Action, Observation]] = field(default_factory=list)
inputs: dict = field(default_factory=dict)
outputs: dict = field(default_factory=dict)
error: str | None = None
last_error: str | None = None
agent_state: AgentState = AgentState.LOADING
resume_state: AgentState | None = None
metrics: Metrics = Metrics()

View File

@@ -154,6 +154,7 @@ class AppConfig(metaclass=Singleton):
sandbox_timeout: The timeout for the sandbox.
debug: Whether to enable debugging.
enable_auto_lint: Whether to enable auto linting. This is False by default, for regular runs of the app. For evaluation, please set this to True.
enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
@@ -195,6 +196,7 @@ class AppConfig(metaclass=Singleton):
enable_auto_lint: bool = (
False # once enabled, OpenDevin would lint files after editing
)
enable_cli_session: bool = False
file_uploads_max_file_size_mb: int = 0
file_uploads_restrict_file_types: bool = False
file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])

View File

@@ -1,13 +1,13 @@
import asyncio
import os
import sys
from typing import Callable, Optional, Type
from typing import Callable, Type
import agenthub # noqa F401 (we import this to get the agents registered)
from opendevin.controller import AgentController
from opendevin.controller.agent import Agent
from opendevin.controller.state.state import State
from opendevin.core.config import args, get_llm_config_arg
from opendevin.core.config import args, config, get_llm_config_arg
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.schema import AgentState
from opendevin.events import EventSource, EventStream, EventStreamSubscriber
@@ -33,11 +33,11 @@ def read_task_from_stdin() -> str:
async def main(
task_str: str = '',
exit_on_message: bool = False,
fake_user_response_fn: Optional[Callable[[Optional[State]], str]] = None,
sandbox: Optional[Sandbox] = None,
runtime_tools_config: Optional[dict] = None,
fake_user_response_fn: Callable[[State | None], str] | None = None,
sandbox: Sandbox | None = None,
runtime_tools_config: dict | None = None,
sid: str | None = None,
) -> Optional[State]:
) -> State | None:
"""Main coroutine to run the agent controller with task input flexibility.
It's only used when you launch opendevin backend directly via cmdline.
@@ -82,16 +82,33 @@ async def main(
)
llm = LLM(args.model_name)
# set up the agent
AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
agent = AgentCls(llm=llm)
event_stream = EventStream('main' + ('_' + sid if sid else ''))
# set up the event stream
cli_session = 'main' + ('_' + sid if sid else '')
event_stream = EventStream(cli_session)
# restore cli session if enabled
initial_state = None
if config.enable_cli_session:
try:
logger.info('Restoring agent state from cli session')
initial_state = State.restore_from_session(cli_session)
except Exception as e:
print('Error restoring state', e)
# init controller with this initial state
controller = AgentController(
agent=agent,
max_iterations=args.max_iterations,
max_budget_per_task=args.max_budget_per_task,
event_stream=event_stream,
initial_state=initial_state,
)
# runtime and tools
runtime = ServerRuntime(event_stream=event_stream, sandbox=sandbox)
runtime.init_sandbox_plugins(controller.agent.sandbox_plugins)
runtime.init_runtime_tools(
@@ -110,7 +127,18 @@ async def main(
task = f.read()
logger.info(f'Dynamic Eval task: {task}')
await event_stream.add_event(MessageAction(content=task), EventSource.USER)
# start event is a MessageAction with the task, either resumed or new
if config.enable_cli_session and initial_state is not None:
# we're resuming the previous session
await event_stream.add_event(
MessageAction(
content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
),
EventSource.USER,
)
elif initial_state is None:
# init with the provided task
await event_stream.add_event(MessageAction(content=task), EventSource.USER)
async def on_event(event: Event):
if isinstance(event, AgentStateChangedObservation):
@@ -134,6 +162,12 @@ async def main(
]:
await asyncio.sleep(1) # Give back control for a tick, so the agent can run
# save session when we're about to close
if config.enable_cli_session:
end_state = controller.get_state()
end_state.save_to_session(cli_session)
# close when done
await controller.close()
runtime.close()
return controller.get_state()

View File

@@ -111,7 +111,7 @@ class AgentSession:
)
try:
agent_state = State.restore_from_session(self.sid)
self.controller.set_state(agent_state)
self.controller._set_initial_state(agent_state)
logger.info(f'Restored agent state from session, sid: {self.sid}')
except Exception as e:
print('Error restoring state', e)

View File

@@ -40,7 +40,7 @@ def test_write_simple_script():
task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
final_state: State = asyncio.run(main(task, exit_on_message=True))
assert final_state.agent_state == AgentState.STOPPED
assert final_state.error is None
assert final_state.last_error is None
# Verify the script file exists
script_path = os.path.join(workspace_base, 'hello.sh')
@@ -86,7 +86,7 @@ def test_edits():
task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
final_state: State = asyncio.run(main(task, exit_on_message=True))
assert final_state.agent_state == AgentState.STOPPED
assert final_state.error is None
assert final_state.last_error is None
# Verify bad.txt has been fixed
text = """This is a stupid typo.
@@ -112,7 +112,7 @@ def test_ipython():
task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
final_state: State = asyncio.run(main(task, exit_on_message=True))
assert final_state.agent_state == AgentState.STOPPED
assert final_state.error is None
assert final_state.last_error is None
# Verify the file exists
file_path = os.path.join(workspace_base, 'test.txt')
@@ -140,7 +140,7 @@ def test_simple_task_rejection():
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
final_state: State = asyncio.run(main(task))
assert final_state.agent_state == AgentState.STOPPED
assert final_state.error is None
assert final_state.last_error is None
assert isinstance(final_state.history[-1][0], AgentRejectAction)
@@ -157,7 +157,7 @@ def test_ipython_module():
task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
final_state: State = asyncio.run(main(task, exit_on_message=True))
assert final_state.agent_state == AgentState.STOPPED
assert final_state.error is None
assert final_state.last_error is None
# Verify the file exists
file_path = os.path.join(workspace_base, 'test.txt')
@@ -186,6 +186,6 @@ def test_browse_internet(http_server):
task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
final_state: State = asyncio.run(main(task, exit_on_message=True))
assert final_state.agent_state == AgentState.STOPPED
assert final_state.error is None
assert final_state.last_error is None
assert isinstance(final_state.history[-1][0], AgentFinishAction)
assert 'OpenDevin is all you need!' in str(final_state.history)