Allow attaching to existing sessions without reinitializing the runtime (#4329)

Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
Robert Brennan
2024-10-14 11:24:29 -04:00
committed by GitHub
parent 640ce0f60d
commit 63ff69fd97
11 changed files with 127 additions and 143 deletions

View File

@@ -25,7 +25,6 @@ from fastapi import (
FastAPI,
HTTPException,
Request,
Response,
UploadFile,
WebSocket,
status,
@@ -40,7 +39,6 @@ import openhands.agenthub # noqa F401 (we import this to get the agents registe
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig, load_app_config
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState # Add this import
from openhands.events.action import (
ChangeAgentStateAction,
FileReadAction,
@@ -213,8 +211,10 @@ async def attach_session(request: Request, call_next):
content={'error': 'Invalid token'},
)
request.state.session = session_manager.get_session(request.state.sid)
if request.state.session is None:
request.state.conversation = session_manager.attach_to_conversation(
request.state.sid
)
if request.state.conversation is None:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
@@ -434,12 +434,13 @@ async def list_files(request: Request, path: str | None = None):
Raises:
HTTPException: If there's an error listing the files.
"""
if not request.state.session.agent_session.runtime:
if not request.state.conversation.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
file_list = await sync_from_async(runtime.list_files, path)
if path:
file_list = [os.path.join(path, f) for f in file_list]
@@ -485,7 +486,7 @@ async def select_file(file: str, request: Request):
Raises:
HTTPException: If there's an error opening the file.
"""
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
read_action = FileReadAction(file)
@@ -567,7 +568,7 @@ async def upload_file(request: Request, files: list[UploadFile]):
tmp_file.write(file_contents)
tmp_file.flush()
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
runtime.copy_to(
tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
)
@@ -635,35 +636,6 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel):
)
@app.get('/api/root_task')
def get_root_task(request: Request):
"""Retrieve the root task of the current agent session.
To get the root_task:
```sh
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
```
Args:
request (Request): The incoming request object.
Returns:
dict: The root task data if available.
Raises:
HTTPException: If the root task is not available.
"""
controller = request.state.session.agent_session.controller
if controller is not None:
state = controller.get_state()
if state:
return JSONResponse(
status_code=status.HTTP_200_OK,
content=state.root_task.to_dict(),
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@app.get('/api/defaults')
async def appconfig_defaults():
"""Retrieve the default configuration settings.
@@ -700,22 +672,6 @@ async def save_file(request: Request):
- 500 error if there's an unexpected error during the save operation.
"""
try:
# Get the agent's current state
controller = request.state.session.agent_session.controller
agent_state = controller.get_agent_state()
# Check if the agent is in an allowed state for editing
if agent_state not in [
AgentState.INIT,
AgentState.PAUSED,
AgentState.FINISHED,
AgentState.AWAITING_USER_INPUT,
]:
raise HTTPException(
status_code=403,
detail='Code editing is only allowed when the agent is paused, finished, or awaiting user input',
)
# Extract file path and content from the request
data = await request.json()
file_path = data.get('filePath')
@@ -726,7 +682,7 @@ async def save_file(request: Request):
raise HTTPException(status_code=400, detail='Missing filePath or content')
# Save the file to the agent's runtime file store
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
file_path = os.path.join(
runtime.config.workspace_mount_path_in_sandbox, file_path
)
@@ -768,13 +724,11 @@ async def security_api(request: Request):
Raises:
HTTPException: If the security analyzer is not initialized.
"""
if not request.state.session.agent_session.security_analyzer:
if not request.state.conversation.security_analyzer:
raise HTTPException(status_code=404, detail='Security analyzer not initialized')
return (
await request.state.session.agent_session.security_analyzer.handle_api_request(
request
)
return await request.state.conversation.security_analyzer.handle_api_request(
request
)
@@ -782,7 +736,7 @@ async def security_api(request: Request):
async def zip_current_workspace(request: Request):
try:
logger.info('Zipping workspace')
runtime: Runtime = request.state.session.agent_session.runtime
runtime: Runtime = request.state.conversation.runtime
path = runtime.config.workspace_mount_path_in_sandbox
zip_file_bytes = runtime.copy_from(path)

View File

@@ -0,0 +1,36 @@
from openhands.core.config import AppConfig
from openhands.events.stream import EventStream
from openhands.runtime import get_runtime_cls
from openhands.runtime.runtime import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage.files import FileStore
class Conversation:
sid: str
file_store: FileStore
event_stream: EventStream
runtime: Runtime
def __init__(
self,
sid: str,
file_store: FileStore,
config: AppConfig,
):
self.sid = sid
self.config = config
self.file_store = file_store
self.event_stream = EventStream(sid, file_store)
if config.security.security_analyzer:
self.security_analyzer = options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer
)(self.event_stream)
runtime_cls = get_runtime_cls(self.config.runtime)
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
sid=self.sid,
attach_to_existing=True,
)

View File

@@ -6,7 +6,9 @@ from fastapi import WebSocket
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import session_exists
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import Session
from openhands.storage.files import FileStore
@@ -44,6 +46,11 @@ class SessionManager:
return None
return self._sessions.get(sid)
def attach_to_conversation(self, sid: str) -> Conversation | None:
if not session_exists(sid, self.file_store):
return None
return Conversation(sid, file_store=self.file_store, config=self.config)
async def send(self, sid: str, data: dict[str, object]) -> bool:
"""Sends data to the client."""
session = self.get_session(sid)