mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
Allow attaching to existing sessions without reinitializing the runtime (#4329)
Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
@@ -25,7 +25,6 @@ from fastapi import (
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
WebSocket,
|
||||
status,
|
||||
@@ -40,7 +39,6 @@ import openhands.agenthub # noqa F401 (we import this to get the agents registe
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, load_app_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState # Add this import
|
||||
from openhands.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
FileReadAction,
|
||||
@@ -213,8 +211,10 @@ async def attach_session(request: Request, call_next):
|
||||
content={'error': 'Invalid token'},
|
||||
)
|
||||
|
||||
request.state.session = session_manager.get_session(request.state.sid)
|
||||
if request.state.session is None:
|
||||
request.state.conversation = session_manager.attach_to_conversation(
|
||||
request.state.sid
|
||||
)
|
||||
if request.state.conversation is None:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Session not found'},
|
||||
@@ -434,12 +434,13 @@ async def list_files(request: Request, path: str | None = None):
|
||||
Raises:
|
||||
HTTPException: If there's an error listing the files.
|
||||
"""
|
||||
if not request.state.session.agent_session.runtime:
|
||||
if not request.state.conversation.runtime:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'Runtime not yet initialized'},
|
||||
)
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
file_list = await sync_from_async(runtime.list_files, path)
|
||||
if path:
|
||||
file_list = [os.path.join(path, f) for f in file_list]
|
||||
@@ -485,7 +486,7 @@ async def select_file(file: str, request: Request):
|
||||
Raises:
|
||||
HTTPException: If there's an error opening the file.
|
||||
"""
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
|
||||
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
|
||||
read_action = FileReadAction(file)
|
||||
@@ -567,7 +568,7 @@ async def upload_file(request: Request, files: list[UploadFile]):
|
||||
tmp_file.write(file_contents)
|
||||
tmp_file.flush()
|
||||
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
runtime.copy_to(
|
||||
tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
|
||||
)
|
||||
@@ -635,35 +636,6 @@ async def submit_feedback(request: Request, feedback: FeedbackDataModel):
|
||||
)
|
||||
|
||||
|
||||
@app.get('/api/root_task')
|
||||
def get_root_task(request: Request):
|
||||
"""Retrieve the root task of the current agent session.
|
||||
|
||||
To get the root_task:
|
||||
```sh
|
||||
curl -H "Authorization: Bearer <TOKEN>" http://localhost:3000/api/root_task
|
||||
```
|
||||
|
||||
Args:
|
||||
request (Request): The incoming request object.
|
||||
|
||||
Returns:
|
||||
dict: The root task data if available.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the root task is not available.
|
||||
"""
|
||||
controller = request.state.session.agent_session.controller
|
||||
if controller is not None:
|
||||
state = controller.get_state()
|
||||
if state:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=state.root_task.to_dict(),
|
||||
)
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
@app.get('/api/defaults')
|
||||
async def appconfig_defaults():
|
||||
"""Retrieve the default configuration settings.
|
||||
@@ -700,22 +672,6 @@ async def save_file(request: Request):
|
||||
- 500 error if there's an unexpected error during the save operation.
|
||||
"""
|
||||
try:
|
||||
# Get the agent's current state
|
||||
controller = request.state.session.agent_session.controller
|
||||
agent_state = controller.get_agent_state()
|
||||
|
||||
# Check if the agent is in an allowed state for editing
|
||||
if agent_state not in [
|
||||
AgentState.INIT,
|
||||
AgentState.PAUSED,
|
||||
AgentState.FINISHED,
|
||||
AgentState.AWAITING_USER_INPUT,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail='Code editing is only allowed when the agent is paused, finished, or awaiting user input',
|
||||
)
|
||||
|
||||
# Extract file path and content from the request
|
||||
data = await request.json()
|
||||
file_path = data.get('filePath')
|
||||
@@ -726,7 +682,7 @@ async def save_file(request: Request):
|
||||
raise HTTPException(status_code=400, detail='Missing filePath or content')
|
||||
|
||||
# Save the file to the agent's runtime file store
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
file_path = os.path.join(
|
||||
runtime.config.workspace_mount_path_in_sandbox, file_path
|
||||
)
|
||||
@@ -768,13 +724,11 @@ async def security_api(request: Request):
|
||||
Raises:
|
||||
HTTPException: If the security analyzer is not initialized.
|
||||
"""
|
||||
if not request.state.session.agent_session.security_analyzer:
|
||||
if not request.state.conversation.security_analyzer:
|
||||
raise HTTPException(status_code=404, detail='Security analyzer not initialized')
|
||||
|
||||
return (
|
||||
await request.state.session.agent_session.security_analyzer.handle_api_request(
|
||||
request
|
||||
)
|
||||
return await request.state.conversation.security_analyzer.handle_api_request(
|
||||
request
|
||||
)
|
||||
|
||||
|
||||
@@ -782,7 +736,7 @@ async def security_api(request: Request):
|
||||
async def zip_current_workspace(request: Request):
|
||||
try:
|
||||
logger.info('Zipping workspace')
|
||||
runtime: Runtime = request.state.session.agent_session.runtime
|
||||
runtime: Runtime = request.state.conversation.runtime
|
||||
|
||||
path = runtime.config.workspace_mount_path_in_sandbox
|
||||
zip_file_bytes = runtime.copy_from(path)
|
||||
|
||||
36
openhands/server/session/conversation.py
Normal file
36
openhands/server/session/conversation.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.security import SecurityAnalyzer, options
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class Conversation:
|
||||
sid: str
|
||||
file_store: FileStore
|
||||
event_stream: EventStream
|
||||
runtime: Runtime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
config: AppConfig,
|
||||
):
|
||||
self.sid = sid
|
||||
self.config = config
|
||||
self.file_store = file_store
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
if config.security.security_analyzer:
|
||||
self.security_analyzer = options.SecurityAnalyzers.get(
|
||||
config.security.security_analyzer, SecurityAnalyzer
|
||||
)(self.event_stream)
|
||||
|
||||
runtime_cls = get_runtime_cls(self.config.runtime)
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
attach_to_existing=True,
|
||||
)
|
||||
@@ -6,7 +6,9 @@ from fastapi import WebSocket
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.stream import session_exists
|
||||
from openhands.runtime.utils.shutdown_listener import should_continue
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.session import Session
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
@@ -44,6 +46,11 @@ class SessionManager:
|
||||
return None
|
||||
return self._sessions.get(sid)
|
||||
|
||||
def attach_to_conversation(self, sid: str) -> Conversation | None:
|
||||
if not session_exists(sid, self.file_store):
|
||||
return None
|
||||
return Conversation(sid, file_store=self.file_store, config=self.config)
|
||||
|
||||
async def send(self, sid: str, data: dict[str, object]) -> bool:
|
||||
"""Sends data to the client."""
|
||||
session = self.get_session(sid)
|
||||
|
||||
Reference in New Issue
Block a user