Compare commits

...

2 Commits

Author SHA1 Message Date
Robert Brennan
dd80f1a0c8 fix check for convo 2024-12-26 17:33:17 -05:00
Robert Brennan
3a2892b5b6 better state restore logic 2024-12-26 15:05:42 -05:00
5 changed files with 32 additions and 16 deletions

View File

@@ -14,6 +14,7 @@ from openhands.events.action.agent import AgentFinishAction
from openhands.events.event import Event, EventSource
from openhands.llm.metrics import Metrics
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
class TrafficControlState(str, Enum):
@@ -112,9 +113,13 @@ class State:
raise e
@staticmethod
def restore_from_session(sid: str, file_store: FileStore) -> 'State':
async def restore_from_conversation_files(
sid: str, file_store: FileStore
) -> 'State':
try:
encoded = file_store.read(f'sessions/{sid}/agent_state.pkl')
encoded = await call_sync_from_async(
file_store.read, f'sessions/{sid}/agent_state.pkl'
)
pickled = base64.b64decode(encoded)
state = pickle.loads(pickled)
except Exception as e:

View File

@@ -140,7 +140,7 @@ async def run_controller(
logger.debug(
f'Trying to restore agent state from cli session {event_stream.sid} if available'
)
initial_state = State.restore_from_session(
initial_state = await State.restore_from_conversation_files(
event_stream.sid, event_stream.file_store
)
except Exception as e:

View File

@@ -29,7 +29,7 @@ class EventStreamSubscriber(str, Enum):
TEST = 'test'
async def session_exists(sid: str, file_store: FileStore) -> bool:
async def conversation_exists(sid: str, file_store: FileStore) -> bool:
try:
await call_sync_from_async(file_store.list, get_conversation_dir(sid))
return True

View File

@@ -10,7 +10,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.events.stream import EventStream, conversation_exists
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
@@ -53,7 +53,6 @@ class AgentSession:
"""
self.sid = sid
self.event_stream = EventStream(sid, file_store)
self.file_store = file_store
self._status_callback = status_callback
@@ -117,6 +116,8 @@ class AgentSession:
github_token: str | None = None,
selected_repository: str | None = None,
):
is_existing_conversation = await conversation_exists(self.sid, self.file_store)
self.event_stream = EventStream(self.sid, self.file_store)
if self._closed:
logger.warning('Session closed before starting')
return
@@ -130,13 +131,14 @@ class AgentSession:
selected_repository=selected_repository,
)
self.controller = self._create_controller(
self.controller = await self._create_controller(
agent,
config.security.confirmation_mode,
max_iterations,
max_budget_per_task=max_budget_per_task,
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
is_existing_conversation=is_existing_conversation,
)
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
@@ -247,11 +249,12 @@ class AgentSession:
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
)
def _create_controller(
async def _create_controller(
self,
agent: Agent,
confirmation_mode: bool,
max_iterations: int,
is_existing_conversation: bool,
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
@@ -262,6 +265,7 @@ class AgentSession:
- agent:
- confirmation_mode: Whether to use confirmation mode
- max_iterations:
- is_existing_conversation:
- max_budget_per_task:
- agent_to_llm_config:
- agent_configs:
@@ -304,11 +308,18 @@ class AgentSession:
headless_mode=False,
status_callback=self._status_callback,
)
try:
agent_state = State.restore_from_session(self.sid, self.file_store)
controller.set_initial_state(agent_state, max_iterations, confirmation_mode)
logger.debug(f'Restored agent state from session, sid: {self.sid}')
except Exception as e:
logger.debug(f'State could not be restored: {e}')
if is_existing_conversation:
logger.info(f'Restoring agent state from conversation: {self.sid}')
try:
agent_state = await State.restore_from_conversation_files(
self.sid, self.file_store
)
controller.set_initial_state(
agent_state, max_iterations, confirmation_mode
)
logger.debug(f'Restored agent state from session, sid: {self.sid}')
except Exception as e:
logger.error(f'State could not be restored: {e}')
raise
logger.debug('Agent controller initialized.')
return controller

View File

@@ -8,7 +8,7 @@ import socketio
from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream, session_exists
from openhands.events.stream import EventStream, conversation_exists
from openhands.server.session.conversation import Conversation
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import ROOM_KEY, Session
@@ -146,7 +146,7 @@ class SessionManager:
async def attach_to_conversation(self, sid: str) -> Conversation | None:
start_time = time.time()
if not await session_exists(sid, self.file_store):
if not await conversation_exists(sid, self.file_store):
return None
async with self._conversations_lock: