Compare commits

...

33 Commits

Author SHA1 Message Date
openhands
c377f9e9ae Merge main into feature/runtime-manager 2025-01-03 20:37:08 +00:00
Robert Brennan
9a6084c6d5 only destroy runtime if no one is active 2024-12-24 16:13:45 -05:00
Robert Brennan
30c1d032e3 move close 2024-12-24 16:12:24 -05:00
Robert Brennan
615eabe5ed add log statement 2024-12-24 16:08:16 -05:00
Robert Brennan
3ecd214d69 Merge branch 'main' into feature/runtime-manager 2024-12-24 15:54:40 -05:00
Robert Brennan
c9a6402103 add cleanup logic 2024-12-24 15:52:58 -05:00
Robert Brennan
33a1dd89e7 remove conversation logic 2024-12-24 15:48:11 -05:00
Robert Brennan
d3f726df51 change connect logic 2024-12-24 15:33:04 -05:00
Robert Brennan
333f9a5bdf fix tests 2024-12-24 15:27:57 -05:00
Robert Brennan
0d454d46f2 Revert "refactor: move runtime creation logic to RuntimeManager.get_runtime()"
This reverts commit 42730014d5.
2024-12-24 15:16:32 -05:00
Robert Brennan
e7685f185c fix cli 2024-12-24 15:15:08 -05:00
Robert Brennan
749da6367e update cli 2024-12-24 15:14:18 -05:00
Robert Brennan
4b497c8e64 fix runtime_manager plumbing 2024-12-24 15:12:15 -05:00
openhands
42730014d5 refactor: move runtime creation logic to RuntimeManager.get_runtime() 2024-12-24 19:58:14 +00:00
openhands
81110671b2 refactor: use singleton RuntimeManager from shared.py 2024-12-24 19:57:06 +00:00
openhands
25f3349e1a refactor: use RuntimeManager to get existing runtime in Conversation 2024-12-24 19:52:35 +00:00
openhands
30f6166bf6 fix: Fix async mock in test_process_issue 2024-12-24 17:55:26 +00:00
Robert Brennan
1f706fe2f2 fix test 2024-12-24 12:32:02 -05:00
Robert Brennan
4123c65317 Merge branch 'main' into feature/runtime-manager 2024-12-24 12:04:33 -05:00
Robert Brennan
6dfd54be9f fix plumbing 2024-12-24 11:53:17 -05:00
openhands
8eef9b2563 Use server's shared config for RuntimeManager 2024-12-24 16:39:27 +00:00
openhands
5d5978c6cb Move runtime class resolution to RuntimeManager and remove redundant error callback 2024-12-24 16:35:14 +00:00
openhands
1a17972b4e Move RuntimeManager to module level and simplify config handling 2024-12-24 16:30:14 +00:00
openhands
4de7a4f85d Simplify RuntimeManager config handling and fix type issues 2024-12-24 16:25:24 +00:00
openhands
8befeca41d Fix linting issues and add missing await for create_runtime 2024-12-24 16:16:16 +00:00
openhands
918139e886 Move AppConfig to RuntimeManager class level and update initialization flow 2024-12-24 16:13:09 +00:00
openhands
6374174095 Update main.py to use RuntimeManager for runtime creation 2024-12-24 16:02:47 +00:00
Robert Brennan
138f6932eb move import 2024-12-24 10:08:13 -05:00
Robert Brennan
7181efd26d move import 2024-12-24 10:07:24 -05:00
Robert Brennan
3a52360ab0 remove exit 2024-12-24 10:05:07 -05:00
openhands
cd9eb1d85c Fix singleton import and add tests for RuntimeManager 2024-12-24 15:03:10 +00:00
openhands
ada657b476 Fix linting issues in runtime_manager.py 2024-12-24 14:54:11 +00:00
openhands
b630d65626 Add RuntimeManager for centralized runtime management 2024-12-24 14:52:19 +00:00
19 changed files with 253 additions and 125 deletions

View File

@@ -32,6 +32,8 @@ from openhands.events.observation import (
FileEditObservation,
NullObservation,
)
from openhands.runtime.base import Runtime
from openhands.runtime.runtime_manager import RuntimeManager
def display_message(message: str):
@@ -162,8 +164,6 @@ async def main(loop):
event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
await runtime.connect()
asyncio.create_task(prompt_for_next_task())
await run_agent_until_done(

View File

@@ -29,6 +29,7 @@ from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization.event import event_to_trajectory
from openhands.runtime.base import Runtime
from openhands.runtime.runtime_manager import RuntimeManager
class FakeUserResponseFunc(Protocol):
@@ -51,6 +52,7 @@ def read_task_from_stdin() -> str:
return sys.stdin.read()
async def run_controller(
config: AppConfig,
initial_user_action: Action,
@@ -79,8 +81,7 @@ async def run_controller(
sid = sid or generate_sid(config)
if runtime is None:
runtime = create_runtime(config, sid=sid, headless_mode=headless_mode)
await runtime.connect()
runtime = await create_runtime(config, sid=sid, headless_mode=headless_mode)
event_stream = runtime.event_stream

View File

@@ -199,8 +199,7 @@ async def process_issue(
)
config.set_llm_config(llm_config)
runtime = create_runtime(config)
await runtime.connect()
runtime = await create_runtime(config)
def on_event(evt):
logger.info(evt)

View File

@@ -0,0 +1,79 @@
from typing import Dict, List, Optional
from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils.singleton import Singleton
class RuntimeManager(metaclass=Singleton):
def __init__(self, config: AppConfig):
self._runtimes: Dict[str, Runtime] = {}
self._config = config
@property
def config(self) -> AppConfig:
return self._config
async def create_runtime(
self,
event_stream: EventStream,
sid: str,
plugins: Optional[List[PluginRequirement]] = None,
env_vars: Optional[Dict[str, str]] = None,
status_callback=None,
attach_to_existing: bool = False,
headless_mode: bool = False,
) -> Runtime:
if sid in self._runtimes:
raise RuntimeError(f'Runtime with ID {sid} already exists')
runtime_class = get_runtime_cls(self.config.runtime)
logger.debug(f'Initializing runtime: {runtime_class.__name__}')
runtime = runtime_class(
config=self.config,
event_stream=event_stream,
sid=sid,
plugins=plugins,
env_vars=env_vars,
status_callback=status_callback,
attach_to_existing=attach_to_existing,
headless_mode=headless_mode,
)
try:
await runtime.connect()
except AgentRuntimeUnavailableError as e:
logger.error(f'Runtime initialization failed: {e}', exc_info=True)
if status_callback:
status_callback('error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e))
raise
self._runtimes[sid] = runtime
logger.info(
f'Created runtime with ID: {sid}. There are now {len(self._runtimes)} runtimes active.'
)
return runtime
def get_runtime(self, sid: str) -> Optional[Runtime]:
return self._runtimes.get(sid)
def list_runtimes(self) -> List[str]:
return list(self._runtimes.keys())
def destroy_runtime(self, sid: str) -> bool:
runtime = self._runtimes.get(sid)
if runtime:
del self._runtimes[sid]
runtime.close()
logger.info(f'Destroyed runtime with ID: {sid}')
return True
return False
async def destroy_all_runtimes(self):
for runtime_id in list(self._runtimes.keys()):
self.destroy_runtime(runtime_id)

View File

@@ -0,0 +1,14 @@
class Singleton(type):
"""Metaclass for creating singleton classes.
Usage:
class MyClass(metaclass=Singleton):
pass
"""
_instances: dict = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]

View File

@@ -15,6 +15,7 @@ from openhands.server.middleware import (
LocalhostCORSMiddleware,
NoCacheMiddleware,
RateLimitMiddleware,
session_manager,
)
from openhands.server.routes.conversation import app as conversation_api_router
from openhands.server.routes.feedback import app as feedback_api_router
@@ -26,7 +27,7 @@ from openhands.server.routes.manage_conversations import (
from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
from openhands.server.routes.settings import app as settings_router
from openhands.server.shared import openhands_config, session_manager
from openhands.server.shared import openhands_config
from openhands.utils.import_utils import get_impl

View File

@@ -14,9 +14,10 @@ from openhands.events.observation import (
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.middleware import session_manager
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.session.manager import ConversationDoesNotExistError
from openhands.server.shared import config, openhands_config, session_manager, sio
from openhands.server.shared import config, openhands_config, sio
from openhands.server.types import AppMode
from openhands.utils.async_utils import call_sync_from_async

View File

@@ -10,9 +10,12 @@ from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from openhands.server.shared import session_manager
from openhands.server.session import SessionManager
from openhands.server.shared import config, file_store, runtime_manager, sio
from openhands.server.types import SessionMiddlewareInterface
session_manager = SessionManager(sio, config, file_store)
class LocalhostCORSMiddleware(CORSMiddleware):
"""
@@ -134,10 +137,17 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
"""
Attach the user's session based on the provided authentication token.
"""
request.state.conversation = await session_manager.attach_to_conversation(
request.state.sid
)
if not request.state.conversation:
request.state.runtime = runtime_manager.get_runtime(request.state.sid)
if request.state.runtime is None:
event_stream = await session_manager.get_event_stream(request.state.sid)
if event_stream:
request.state.runtime = await runtime_manager.create_runtime(
event_stream=event_stream,
sid=request.state.sid,
attach_to_existing=True,
headless_mode=False,
)
if not request.state.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Session not found'},
@@ -148,7 +158,7 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
"""
Detach the user's session.
"""
await session_manager.detach_from_conversation(request.state.conversation)
pass
async def __call__(self, request: Request, call_next: Callable):
if not self._should_attach(request):

View File

@@ -13,7 +13,7 @@ async def get_remote_runtime_config(request: Request):
Currently, this is the session ID and runtime ID (if available).
"""
runtime = request.state.conversation.runtime
runtime = request.state.runtime
runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None
session_id = runtime.sid if hasattr(runtime, 'sid') else None
return JSONResponse(
@@ -37,7 +37,7 @@ async def get_vscode_url(request: Request):
JSONResponse: A JSON response indicating the success of the operation.
"""
try:
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = request.state.runtime
logger.debug(f'Runtime type: {type(runtime)}')
logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}')
return JSONResponse(status_code=200, content={'vscode_url': runtime.vscode_url})
@@ -81,12 +81,12 @@ async def search_events(
HTTPException: If conversation is not found
ValueError: If limit is less than 1 or greater than 100
"""
if not request.state.conversation:
if not request.state.runtime:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail='Conversation not found'
)
# Get matching events from the stream
event_stream = request.state.conversation.event_stream
event_stream = request.state.runtime.event_stream
matching_events = event_stream.get_matching_events(
query=query,
event_type=event_type,

View File

@@ -35,7 +35,7 @@ async def submit_feedback(request: Request, conversation_id: str):
# and there is a function to handle the storage.
body = await request.json()
async_stream = AsyncEventStreamWrapper(
request.state.conversation.event_stream, filter_hidden=True
request.state.runtime.event_stream, filter_hidden=True
)
trajectory = []
async for event in async_stream:

View File

@@ -58,13 +58,13 @@ async def list_files(request: Request, conversation_id: str, path: str | None =
Raises:
HTTPException: If there's an error listing the files.
"""
if not request.state.conversation.runtime:
if not request.state.runtime:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'Runtime not yet initialized'},
)
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = request.state.runtime
try:
file_list = await call_sync_from_async(runtime.list_files, path)
except AgentRuntimeUnavailableError as e:
@@ -124,7 +124,7 @@ async def select_file(file: str, request: Request):
Raises:
HTTPException: If there's an error opening the file.
"""
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = request.state.runtime
file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
read_action = FileReadAction(file)
@@ -199,7 +199,7 @@ async def upload_file(request: Request, conversation_id: str, files: list[Upload
tmp_file.write(file_contents)
tmp_file.flush()
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = request.state.runtime
try:
await call_sync_from_async(
runtime.copy_to,
@@ -276,7 +276,7 @@ async def save_file(request: Request):
raise HTTPException(status_code=400, detail='Missing filePath or content')
# Save the file to the agent's runtime file store
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = request.state.runtime
file_path = os.path.join(
runtime.config.workspace_mount_path_in_sandbox, file_path
)
@@ -316,7 +316,7 @@ async def zip_current_workspace(
):
try:
logger.debug('Zipping workspace')
runtime: Runtime = request.state.conversation.runtime
runtime: Runtime = request.state.runtime
path = runtime.config.workspace_mount_path_in_sandbox
try:
zip_file = await call_sync_from_async(runtime.copy_from, path)

View File

@@ -0,0 +1,77 @@
import uuid
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from github import Github
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.server.middleware import session_manager
from openhands.server.routes.settings import SettingsStoreImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config
from openhands.storage.conversation.conversation_store import (
ConversationMetadata,
ConversationStore,
)
from openhands.utils.async_utils import call_sync_from_async
app = APIRouter(prefix='/api')
class InitSessionRequest(BaseModel):
github_token: str | None = None
latest_event_id: int = -1
selected_repository: str | None = None
args: dict | None = None
@app.post('/conversations')
async def new_conversation(request: Request, data: InitSessionRequest):
"""Initialize a new session or join an existing one.
After successful initialization, the client should connect to the WebSocket
using the returned conversation ID
"""
github_token = ''
if data.github_token:
github_token = data.github_token
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings = await settings_store.load()
session_init_args: dict = {}
if settings:
session_init_args = {**settings.__dict__, **session_init_args}
if data.args:
for key, value in data.args.items():
session_init_args[key.lower()] = value
session_init_args['github_token'] = github_token
session_init_args['selected_repository'] = data.selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
conversation_store = await ConversationStore.get_instance(config)
conversation_id = uuid.uuid4().hex
while await conversation_store.exists(conversation_id):
logger.warning(f'Collision on conversation ID: {conversation_id}. Retrying...')
conversation_id = uuid.uuid4().hex
user_id = ''
if data.github_token:
g = Github(data.github_token)
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id
await conversation_store.save_metadata(
ConversationMetadata(
conversation_id=conversation_id,
github_user_id=user_id,
selected_repository=data.selected_repository,
)
)
await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data
)
return JSONResponse(content={'status': 'ok', 'conversation_id': conversation_id})

View File

@@ -4,6 +4,9 @@ from fastapi import (
Request,
)
from openhands.security import SecurityAnalyzer, options
from openhands.server.shared import config
app = APIRouter(prefix='/api/conversations/{conversation_id}')
@@ -22,9 +25,10 @@ async def security_api(request: Request):
Raises:
HTTPException: If the security analyzer is not initialized.
"""
if not request.state.conversation.security_analyzer:
if not request.state.runtime:
raise HTTPException(status_code=404, detail='Security analyzer not initialized')
security_analyzer = options.SecurityAnalyzers.get(
config.security.security_analyzer or '', SecurityAnalyzer
)(request.state.runtime.event_stream)
return await request.state.conversation.security_analyzer.handle_api_request(
request
)
return await security_analyzer.handle_api_request(request)

View File

@@ -4,7 +4,7 @@ from typing import Callable, Optional
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig, AppConfig, LLMConfig
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
@@ -15,6 +15,7 @@ from openhands.microagent import BaseMicroAgent
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.server.shared import runtime_manager
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
@@ -60,8 +61,6 @@ class AgentSession:
async def start(
self,
runtime_name: str,
config: AppConfig,
agent: Agent,
max_iterations: int,
max_budget_per_task: float | None = None,
@@ -72,8 +71,6 @@ class AgentSession:
):
"""Starts the Agent session
Parameters:
- runtime_name: The name of the runtime associated with the session
- config:
- agent:
- max_iterations:
- max_budget_per_task:
@@ -85,14 +82,15 @@ class AgentSession:
'Session already started. You need to close this session and start a new one.'
)
if self._closed:
logger.warning('Session closed before starting')
return
self._initializing = True
self._create_security_analyzer(config.security.security_analyzer)
self._create_security_analyzer(
runtime_manager.config.security.security_analyzer
)
await self._create_runtime(
runtime_name=runtime_name,
config=config,
agent=agent,
github_token=github_token,
selected_repository=selected_repository,
@@ -100,7 +98,7 @@ class AgentSession:
self.controller = self._create_controller(
agent,
config.security.confirmation_mode,
runtime_manager.config.security.confirmation_mode,
max_iterations,
max_budget_per_task=max_budget_per_task,
agent_to_llm_config=agent_to_llm_config,
@@ -136,7 +134,7 @@ class AgentSession:
end_state.save_to_session(self.sid, self.file_store)
await self.controller.close()
if self.runtime is not None:
self.runtime.close()
runtime_manager.destroy_runtime(self.sid)
if self.security_analyzer is not None:
await self.security_analyzer.close()
@@ -159,8 +157,6 @@ class AgentSession:
async def _create_runtime(
self,
runtime_name: str,
config: AppConfig,
agent: Agent,
github_token: str | None = None,
selected_repository: str | None = None,
@@ -168,38 +164,28 @@ class AgentSession:
"""Creates a runtime instance
Parameters:
- runtime_name: The name of the runtime associated with the session
- config:
- agent:
"""
if self.runtime is not None:
raise RuntimeError('Runtime already created')
logger.debug(f'Initializing runtime `{runtime_name}` now...')
runtime_cls = get_runtime_cls(runtime_name)
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
sid=self.sid,
plugins=agent.sandbox_plugins,
status_callback=self._status_callback,
headless_mode=False,
)
# FIXME: this sleep is a terrible hack.
# This is to give the websocket a second to connect, so that
# the status messages make it through to the frontend.
# We should find a better way to plumb status messages through.
await asyncio.sleep(1)
try:
await self.runtime.connect()
self.runtime = await runtime_manager.create_runtime(
event_stream=self.event_stream,
sid=self.sid,
plugins=agent.sandbox_plugins,
status_callback=self._status_callback,
headless_mode=False,
)
except AgentRuntimeUnavailableError as e:
logger.error(f'Runtime initialization failed: {e}', exc_info=True)
if self._status_callback:
self._status_callback(
'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e)
)
return
self.runtime.clone_repo(github_token, selected_repository)

View File

@@ -1,46 +0,0 @@
import asyncio
from openhands.core.config import AppConfig
from openhands.events.stream import EventStream
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
from openhands.security import SecurityAnalyzer, options
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
class Conversation:
sid: str
file_store: FileStore
event_stream: EventStream
runtime: Runtime
def __init__(
self,
sid: str,
file_store: FileStore,
config: AppConfig,
):
self.sid = sid
self.config = config
self.file_store = file_store
self.event_stream = EventStream(sid, file_store)
if config.security.security_analyzer:
self.security_analyzer = options.SecurityAnalyzers.get(
config.security.security_analyzer, SecurityAnalyzer
)(self.event_stream)
runtime_cls = get_runtime_cls(self.config.runtime)
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
sid=self.sid,
attach_to_existing=True,
headless_mode=False,
)
async def connect(self):
await self.runtime.connect()
async def disconnect(self):
asyncio.create_task(call_sync_from_async(self.runtime.close))

View File

@@ -7,12 +7,13 @@ from uuid import uuid4
import socketio
from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.conversation import Conversation
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.settings import Settings
from openhands.server.shared import runtime_manager
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import should_continue
@@ -20,9 +21,6 @@ from openhands.utils.shutdown_listener import should_continue
_REDIS_POLL_TIMEOUT = 1.5
_CHECK_ALIVE_INTERVAL = 15
_CLEANUP_INTERVAL = 15
_CLEANUP_EXCEPTION_WAIT_TIME = 15
class ConversationDoesNotExistError(Exception):
pass
@@ -64,16 +62,12 @@ class SessionManager:
redis_client = self._get_redis_client()
if redis_client:
self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
self._cleanup_task = asyncio.create_task(self._cleanup_detached_conversations())
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if self._redis_listen_task:
self._redis_listen_task.cancel()
self._redis_listen_task = None
if self._cleanup_task:
self._cleanup_task.cancel()
self._cleanup_task = None
def _get_redis_client(self):
redis_client = getattr(self.sio.manager, 'redis', None)
@@ -161,6 +155,7 @@ class SessionManager:
# which can't be guaranteed - nodes can simply vanish unexpectedly!
sid = data['sid']
logger.debug(f'session_closing:{sid}')
await call_sync_from_async(runtime_manager.destroy_runtime, sid)
for (
connection_id,
local_sid,
@@ -209,7 +204,7 @@ class SessionManager:
logger.info(f'join_conversation:{sid}:{connection_id}')
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
event_stream = await self.get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings)
return event_stream
@@ -348,20 +343,23 @@ class SessionManager:
if not await self.is_agent_loop_running(sid):
logger.info(f'start_agent_loop:{sid}')
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
sid=sid,
file_store=self.file_store,
config=self.config,
sio=self.sio,
)
self._local_agent_loops_by_sid[sid] = session
asyncio.create_task(session.initialize_agent(settings))
event_stream = await self._get_event_stream(sid)
event_stream = await self.get_event_stream(sid)
if not event_stream:
logger.error(f'No event stream after starting agent loop: {sid}')
raise RuntimeError(f'no_event_stream:{sid}')
asyncio.create_task(self._cleanup_session_later(sid))
return event_stream
async def _get_event_stream(self, sid: str) -> EventStream | None:
logger.info(f'_get_event_stream:{sid}')
async def get_event_stream(self, sid: str) -> EventStream | None:
logger.info(f'get_event_stream:{sid}')
session = self._local_agent_loops_by_sid.get(sid)
if session:
logger.info(f'found_local_agent_loop:{sid}')
@@ -444,6 +442,7 @@ class SessionManager:
if redis_client and await self._has_remote_connections(sid):
return False
await call_sync_from_async(runtime_manager.destroy_runtime, sid)
# We alert the cluster in case they are interested
if redis_client:
await redis_client.publish(

View File

@@ -50,7 +50,9 @@ class Session:
self.last_active_ts = int(time.time())
self.file_store = file_store
self.agent_session = AgentSession(
sid, file_store, status_callback=self.queue_status_message
sid,
file_store,
status_callback=self.queue_status_message,
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
@@ -102,8 +104,6 @@ class Session:
try:
await self.agent_session.start(
runtime_name=self.config.runtime,
config=self.config,
agent=agent,
max_iterations=max_iterations,
max_budget_per_task=self.config.max_budget_per_task,

View File

@@ -4,8 +4,8 @@ import socketio
from dotenv import load_dotenv
from openhands.core.config import load_app_config
from openhands.runtime.runtime_manager import RuntimeManager
from openhands.server.config.openhands_config import load_openhands_config
from openhands.server.session import SessionManager
from openhands.storage import get_file_store
load_dotenv()
@@ -27,4 +27,4 @@ sio = socketio.AsyncServer(
async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager
)
session_manager = SessionManager(sio, config, file_store)
runtime_manager = RuntimeManager(config)

View File

@@ -326,7 +326,8 @@ async def test_complete_runtime():
@pytest.mark.asyncio
async def test_process_issue(mock_output_dir, mock_prompt_template):
# Mock dependencies
mock_create_runtime = MagicMock()
mock_runtime = MagicMock(connect=AsyncMock())
mock_create_runtime = AsyncMock(return_value=mock_runtime)
mock_initialize_runtime = AsyncMock()
mock_run_controller = AsyncMock()
mock_complete_runtime = AsyncMock()
@@ -408,7 +409,9 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
handler_instance.reset_mock()
# Mock return values
mock_create_runtime.return_value = MagicMock(connect=AsyncMock())
mock_runtime = MagicMock(connect=AsyncMock())
mock_create_runtime.return_value = AsyncMock()
mock_create_runtime.return_value.__aenter__.return_value = mock_runtime
if test_case['run_controller_raises']:
mock_run_controller.side_effect = test_case['run_controller_raises']
else: