From c3d60b31d1bc6a30e772121a875516ca73e01ca1 Mon Sep 17 00:00:00 2001 From: chuckbutkus Date: Wed, 19 Mar 2025 16:03:09 -0400 Subject: [PATCH] All-1465 Move user conversations (#7340) --- openhands/controller/state/state.py | 28 ++++++++-- openhands/core/main.py | 4 +- openhands/events/stream.py | 54 ++++++++++++++----- .../conversation_manager.py | 4 +- .../standalone_conversation_manager.py | 18 ++++--- openhands/server/middleware.py | 4 +- openhands/server/session/agent_session.py | 9 ++-- openhands/server/session/conversation.py | 9 ++-- openhands/storage/locations.py | 29 +++++----- 9 files changed, 114 insertions(+), 45 deletions(-) diff --git a/openhands/controller/state/state.py b/openhands/controller/state/state.py index 58187d1231..549124fc0e 100644 --- a/openhands/controller/state/state.py +++ b/openhands/controller/state/state.py @@ -102,22 +102,42 @@ class State: extra_data: dict[str, Any] = field(default_factory=dict) last_error: str = '' - def save_to_session(self, sid: str, file_store: FileStore): + def save_to_session(self, sid: str, file_store: FileStore, user_id: str | None): pickled = pickle.dumps(self) logger.debug(f'Saving state to session {sid}:{self.agent_state}') encoded = base64.b64encode(pickled).decode('utf-8') try: - file_store.write(get_conversation_agent_state_filename(sid), encoded) + file_store.write( + get_conversation_agent_state_filename(sid, user_id), encoded + ) + + # see if state is in old directory. If yes, delete it. + filename = get_conversation_agent_state_filename(sid) + try: + file_store.delete(filename) + except Exception: + pass except Exception as e: logger.error(f'Failed to save state to session: {e}') raise e @staticmethod - def restore_from_session(sid: str, file_store: FileStore) -> 'State': + def restore_from_session( + sid: str, file_store: FileStore, user_id: str | None = None + ) -> 'State': try: - encoded = file_store.read(get_conversation_agent_state_filename(sid)) + encoded = file_store.read( + get_conversation_agent_state_filename(sid, user_id) + ) pickled = base64.b64decode(encoded) state = pickle.loads(pickled) + except FileNotFoundError: + if user_id: + # see if state is in old directory. If yes, load it. + filename = get_conversation_agent_state_filename(sid) + encoded = file_store.read(filename) + pickled = base64.b64decode(encoded) + state = pickle.loads(pickled) except Exception as e: logger.debug(f'Could not restore state from session: {e}') raise e diff --git a/openhands/core/main.py b/openhands/core/main.py index 01c3819766..fcd8a5468e 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -194,7 +194,9 @@ async def run_controller( if config.file_store is not None and config.file_store != 'memory': end_state = controller.get_state() # NOTE: the saved state does not include delegates events - end_state.save_to_session(event_stream.sid, event_stream.file_store) + end_state.save_to_session( + event_stream.sid, event_stream.file_store, event_stream.user_id + ) await controller.close(set_stop_state=False) diff --git a/openhands/events/stream.py b/openhands/events/stream.py index e1ecb7adbe..9cdaf60b48 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -32,9 +32,11 @@ class EventStreamSubscriber(str, Enum): TEST = 'test' -async def session_exists(sid: str, file_store: FileStore) -> bool: +async def session_exists( + sid: str, file_store: FileStore, user_id: str | None = None +) -> bool: try: - await call_sync_from_async(file_store.list, get_conversation_dir(sid)) + await call_sync_from_async(file_store.list, get_conversation_dir(sid, user_id)) return True except FileNotFoundError: return False @@ -57,6 +59,7 @@ class AsyncEventStreamWrapper: class EventStream: sid: str + user_id: str | None file_store: FileStore secrets: dict[str, str] # For each subscriber ID, there is a map of callback functions - useful @@ -70,9 +73,10 @@ class EventStream: _thread_pools: dict[str, dict[str, ThreadPoolExecutor]] _thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]] - def __init__(self, sid: str, file_store: FileStore): + def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None): self.sid = sid self.file_store = file_store + self.user_id = user_id self._stop_flag = threading.Event() self._queue: queue.Queue[Event] = queue.Queue() self._thread_pools = {} @@ -90,10 +94,24 @@ class EventStream: self.__post_init__() def __post_init__(self) -> None: + events = [] + try: - events = self.file_store.list(get_conversation_events_dir(self.sid)) + events_dir = get_conversation_events_dir(self.sid, self.user_id) + events += self.file_store.list(events_dir) except FileNotFoundError: - logger.debug(f'No events found for session {self.sid}') + logger.debug(f'No events found for session {self.sid} at {events_dir}') + + if self.user_id: + # During transition to new location, try old location if user_id is set + # TODO: remove this code after 5/1/2025 + try: + events_dir = get_conversation_events_dir(self.sid) + events += self.file_store.list(events_dir) + except FileNotFoundError: + logger.debug(f'No events found for session {self.sid} at {events_dir}') + + if not events: self._cur_id = 0 return @@ -156,8 +174,8 @@ class EventStream: del self._subscribers[subscriber_id][callback_id] - def _get_filename_for_id(self, id: int) -> str: - return get_conversation_event_filename(self.sid, id) + def _get_filename_for_id(self, id: int, user_id: str | None) -> str: + return get_conversation_event_filename(self.sid, id, user_id) @staticmethod def _get_id_from_filename(filename: str) -> int: @@ -223,10 +241,20 @@ class EventStream: event_id += 1 def get_event(self, id: int) -> Event: - filename = self._get_filename_for_id(id) - content = self.file_store.read(filename) - data = json.loads(content) - return event_from_dict(data) + filename = self._get_filename_for_id(id, self.user_id) + try: + content = self.file_store.read(filename) + data = json.loads(content) + return event_from_dict(data) + except FileNotFoundError: + logger.debug(f'File {filename} not found') + # TODO remove this block after 5/1/2025 + if self.user_id: + filename = self._get_filename_for_id(id, None) + content = self.file_store.read(filename) + data = json.loads(content) + return event_from_dict(data) + raise def get_latest_event(self) -> Event: return self.get_event(self._cur_id - 1) @@ -277,7 +305,9 @@ class EventStream: data = self._replace_secrets(data) event = event_from_dict(data) if event.id is not None: - self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data)) + self.file_store.write( + self._get_filename_for_id(event.id, self.user_id), json.dumps(data) + ) self._queue.put(event) def set_secrets(self, secrets: dict[str, str]): diff --git a/openhands/server/conversation_manager/conversation_manager.py b/openhands/server/conversation_manager/conversation_manager.py index d152a936f6..86bae39970 100644 --- a/openhands/server/conversation_manager/conversation_manager.py +++ b/openhands/server/conversation_manager/conversation_manager.py @@ -37,7 +37,9 @@ class ConversationManager(ABC): """Clean up the conversation manager.""" @abstractmethod - async def attach_to_conversation(self, sid: str) -> Conversation | None: + async def attach_to_conversation( + self, sid: str, user_id: str | None = None + ) -> Conversation | None: """Attach to an existing conversation or create a new one.""" @abstractmethod diff --git a/openhands/server/conversation_manager/standalone_conversation_manager.py b/openhands/server/conversation_manager/standalone_conversation_manager.py index b36f4025f9..540228a518 100644 --- a/openhands/server/conversation_manager/standalone_conversation_manager.py +++ b/openhands/server/conversation_manager/standalone_conversation_manager.py @@ -63,9 +63,11 @@ class StandaloneConversationManager(ConversationManager): self._cleanup_task.cancel() self._cleanup_task = None - async def attach_to_conversation(self, sid: str) -> Conversation | None: + async def attach_to_conversation( + self, sid: str, user_id: str | None = None + ) -> Conversation | None: start_time = time.time() - if not await session_exists(sid, self.file_store): + if not await session_exists(sid, self.file_store, user_id=user_id): return None async with self._conversations_lock: @@ -88,7 +90,9 @@ class StandaloneConversationManager(ConversationManager): return conversation # Create new conversation if none exists - c = Conversation(sid, file_store=self.file_store, config=self.config) + c = Conversation( + sid, file_store=self.file_store, config=self.config, user_id=user_id + ) try: await c.connect() except AgentRuntimeUnavailableError as e: @@ -119,7 +123,7 @@ class StandaloneConversationManager(ConversationManager): ) await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) self._local_connection_id_to_session_id[connection_id] = sid - event_stream = await self._get_event_stream(sid) + event_stream = await self._get_event_stream(sid, user_id) if not event_stream: return await self.maybe_start_agent_loop( sid, settings, user_id, github_user_id=github_user_id @@ -299,7 +303,7 @@ class StandaloneConversationManager(ConversationManager): except ValueError: pass # Already subscribed - take no action - event_stream = await self._get_event_stream(sid) + event_stream = await self._get_event_stream(sid, user_id) if not event_stream: logger.error( f'No event stream after starting agent loop: {sid}', @@ -308,7 +312,9 @@ class StandaloneConversationManager(ConversationManager): raise RuntimeError(f'no_event_stream:{sid}') return event_stream - async def _get_event_stream(self, sid: str) -> EventStream | None: + async def _get_event_stream( + self, sid: str, user_id: str | None + ) -> EventStream | None: logger.info(f'_get_event_stream:{sid}', extra={'session_id': sid}) session = self._local_agent_loops_by_sid.get(sid) if session: diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 3d690306f7..acb09c287b 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -148,7 +148,9 @@ class AttachConversationMiddleware(SessionMiddlewareInterface): Attach the user's session based on the provided authentication token. """ request.state.conversation = ( - await shared.conversation_manager.attach_to_conversation(request.state.sid) + await shared.conversation_manager.attach_to_conversation( + request.state.sid, get_user_id(request) + ) ) if not request.state.conversation: return JSONResponse( diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 58518ab7a3..5c9c1ca498 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -37,6 +37,7 @@ class AgentSession: """ sid: str + user_id: str | None event_stream: EventStream file_store: FileStore controller: AgentController | None = None @@ -63,7 +64,7 @@ class AgentSession: """ self.sid = sid - self.event_stream = EventStream(sid, file_store) + self.event_stream = EventStream(sid, file_store, user_id) self.file_store = file_store self._status_callback = status_callback self.user_id = user_id @@ -186,7 +187,7 @@ class AgentSession: self.event_stream.close() if self.controller is not None: end_state = self.controller.get_state() - end_state.save_to_session(self.sid, self.file_store) + end_state.save_to_session(self.sid, self.file_store, self.user_id) await self.controller.close() if self.runtime is not None: self.runtime.close() @@ -371,7 +372,9 @@ class AgentSession: # Use a heuristic to figure out if we should have a state: # if we have events in the stream. try: - restored_state = State.restore_from_session(self.sid, self.file_store) + restored_state = State.restore_from_session( + self.sid, self.file_store, self.user_id + ) self.logger.debug(f'Restored state from session, sid: {self.sid}') except Exception as e: if self.event_stream.get_latest_event_id() > 0: diff --git a/openhands/server/session/conversation.py b/openhands/server/session/conversation.py index 14aa1363e6..b827996e4c 100644 --- a/openhands/server/session/conversation.py +++ b/openhands/server/session/conversation.py @@ -14,17 +14,16 @@ class Conversation: file_store: FileStore event_stream: EventStream runtime: Runtime + user_id: str | None def __init__( - self, - sid: str, - file_store: FileStore, - config: AppConfig, + self, sid: str, file_store: FileStore, config: AppConfig, user_id: str | None ): self.sid = sid self.config = config self.file_store = file_store - self.event_stream = EventStream(sid, file_store) + self.user_id = user_id + self.event_stream = EventStream(sid, file_store, user_id) if config.security.security_analyzer: self.security_analyzer = options.SecurityAnalyzers.get( config.security.security_analyzer, SecurityAnalyzer diff --git a/openhands/storage/locations.py b/openhands/storage/locations.py index 450ebe4a6f..43e1661b91 100644 --- a/openhands/storage/locations.py +++ b/openhands/storage/locations.py @@ -1,25 +1,30 @@ CONVERSATION_BASE_DIR = 'sessions' -def get_conversation_dir(sid: str) -> str: - return f'{CONVERSATION_BASE_DIR}/{sid}/' +def get_conversation_dir(sid: str, user_id: str | None = None) -> str: + if user_id: + return f'users/{user_id}/conversations/{sid}/' + else: + return f'{CONVERSATION_BASE_DIR}/{sid}/' -def get_conversation_events_dir(sid: str) -> str: - return f'{get_conversation_dir(sid)}events/' +def get_conversation_events_dir(sid: str, user_id: str | None = None) -> str: + return f'{get_conversation_dir(sid, user_id)}events/' -def get_conversation_event_filename(sid: str, id: int) -> str: - return f'{get_conversation_events_dir(sid)}{id}.json' +def get_conversation_event_filename( + sid: str, id: int, user_id: str | None = None +) -> str: + return f'{get_conversation_events_dir(sid, user_id)}{id}.json' -def get_conversation_metadata_filename(sid: str) -> str: - return f'{get_conversation_dir(sid)}metadata.json' +def get_conversation_metadata_filename(sid: str, user_id: str | None = None) -> str: + return f'{get_conversation_dir(sid, user_id)}metadata.json' -def get_conversation_init_data_filename(sid: str) -> str: - return f'{get_conversation_dir(sid)}init.json' +def get_conversation_init_data_filename(sid: str, user_id: str | None = None) -> str: + return f'{get_conversation_dir(sid, user_id)}init.json' -def get_conversation_agent_state_filename(sid: str) -> str: - return f'{get_conversation_dir(sid)}agent_state.pkl' +def get_conversation_agent_state_filename(sid: str, user_id: str | None = None) -> str: + return f'{get_conversation_dir(sid, user_id)}agent_state.pkl'