Compare commits

...

3 Commits

Author SHA1 Message Date
Ray Myers
9e578e9947 Revert additional returns, preserve existing behavior 2025-06-05 04:34:16 -05:00
Ray Myers
321af8d825 Fix type errors in conversation manager and utils (#8912)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-06-05 04:16:18 -05:00
openhands
75c797fe4a Add stricter type checking for conversation manager and utils 2025-06-05 08:42:57 +00:00
5 changed files with 74 additions and 39 deletions

View File

@@ -7,3 +7,11 @@ warn_unreachable = True
warn_redundant_casts = True
no_implicit_optional = True
strict_optional = True
[mypy-openhands.server.utils]
disallow_incomplete_defs = True
disallow_untyped_defs = True
[mypy-openhands.server.conversation_manager.*]
disallow_incomplete_defs = True
disallow_untyped_defs = True

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import types
from abc import ABC, abstractmethod
import socketio
@@ -47,11 +48,16 @@ class ConversationManager(ABC):
conversation_store: ConversationStore
@abstractmethod
async def __aenter__(self):
async def __aenter__(self) -> 'ConversationManager':
"""Initialize the conversation manager."""
@abstractmethod
async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
"""Clean up the conversation manager."""
@abstractmethod
@@ -61,7 +67,7 @@ class ConversationManager(ABC):
"""Attach to an existing conversation or create a new one."""
@abstractmethod
async def detach_from_conversation(self, conversation: ServerConversation):
async def detach_from_conversation(self, conversation: ServerConversation) -> None:
"""Detach from a conversation."""
@abstractmethod
@@ -103,15 +109,15 @@ class ConversationManager(ABC):
"""Start an event loop if one is not already running"""
@abstractmethod
async def send_to_event_stream(self, connection_id: str, data: dict):
async def send_to_event_stream(self, connection_id: str, data: dict) -> None:
"""Send data to an event stream."""
@abstractmethod
async def disconnect_from_session(self, connection_id: str):
async def disconnect_from_session(self, connection_id: str) -> None:
"""Disconnect from a session."""
@abstractmethod
async def close_session(self, sid: str):
async def close_session(self, sid: str) -> None:
"""Close a session."""
@abstractmethod

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import hashlib
import os
import types
from base64 import urlsafe_b64encode
from dataclasses import dataclass, field
from types import MappingProxyType
@@ -56,11 +57,16 @@ class DockerNestedConversationManager(ConversationManager):
_starting_conversation_ids: set[str] = field(default_factory=set)
_runtime_container_image: str | None = None
async def __aenter__(self):
async def __aenter__(self) -> 'DockerNestedConversationManager':
# No action is required on startup for this implementation
pass
return self
async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
# No action is required on shutdown for this implementation
pass
@@ -70,7 +76,7 @@ class DockerNestedConversationManager(ConversationManager):
# Not supported - clients should connect directly to the nested server!
raise ValueError('unsupported_operation')
async def detach_from_conversation(self, conversation: ServerConversation):
async def detach_from_conversation(self, conversation: ServerConversation) -> None:
# Not supported - clients should connect directly to the nested server!
raise ValueError('unsupported_operation')
@@ -148,7 +154,7 @@ class DockerNestedConversationManager(ConversationManager):
user_id: str | None,
initial_user_msg: MessageAction | None,
replay_json: str | None,
):
) -> None:
logger.info(f'starting_agent_loop:{sid}', extra={'session_id': sid})
await self.ensure_num_conversations_below_limit(sid, user_id)
runtime = await self._create_runtime(sid, user_id, settings)
@@ -190,7 +196,7 @@ class DockerNestedConversationManager(ConversationManager):
initial_user_msg: MessageAction | None,
replay_json: str | None,
api_url: str,
):
) -> None:
try:
await call_sync_from_async(runtime.wait_until_alive)
await call_sync_from_async(runtime.setup_initial_env)
@@ -271,18 +277,20 @@ class DockerNestedConversationManager(ConversationManager):
finally:
self._starting_conversation_ids.discard(sid)
async def send_to_event_stream(self, connection_id: str, data: dict):
async def send_to_event_stream(self, connection_id: str, data: dict) -> None:
# Not supported - clients should connect directly to the nested server!
raise ValueError('unsupported_operation')
async def disconnect_from_session(self, connection_id: str):
async def disconnect_from_session(self, connection_id: str) -> None:
# Not supported - clients should connect directly to the nested server!
raise ValueError('unsupported_operation')
async def close_session(self, sid: str):
async def close_session(self, sid: str) -> None:
stop_all_containers(f'openhands-runtime-{sid}')
async def get_agent_loop_info(self, user_id=None, filter_to_sids=None):
async def get_agent_loop_info(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> list[AgentLoopInfo]:
results = []
containers = self.docker_client.containers.list()
for container in containers:
@@ -353,7 +361,7 @@ class DockerNestedConversationManager(ConversationManager):
nested_url = f'{self.config.sandbox.local_runtime_url}:{container_port}/api/conversations/{conversation_id}'
return nested_url
def _get_session_api_key_for_conversation(self, conversation_id: str):
def _get_session_api_key_for_conversation(self, conversation_id: str) -> str:
jwt_secret = self.config.jwt_secret.get_secret_value() # type:ignore
conversation_key = f'{jwt_secret}:{conversation_id}'.encode()
session_api_key = (
@@ -363,7 +371,9 @@ class DockerNestedConversationManager(ConversationManager):
)
return session_api_key
async def ensure_num_conversations_below_limit(self, sid: str, user_id: str | None):
async def ensure_num_conversations_below_limit(
self, sid: str, user_id: str | None
) -> None:
response_ids = await self.get_running_agent_loops(user_id)
if len(response_ids) >= self.config.max_concurrent_conversations:
logger.info(
@@ -395,7 +405,7 @@ class DockerNestedConversationManager(ConversationManager):
)
await self.close_session(oldest_conversation_id)
def _get_provider_handler(self, settings: Settings):
def _get_provider_handler(self, settings: Settings) -> ProviderHandler:
provider_tokens = None
if isinstance(settings, ConversationInitData):
provider_tokens = settings.git_provider_tokens
@@ -405,7 +415,9 @@ class DockerNestedConversationManager(ConversationManager):
)
return provider_handler
async def _create_runtime(self, sid: str, user_id: str | None, settings: Settings):
async def _create_runtime(
self, sid: str, user_id: str | None, settings: Settings
) -> DockerRuntime:
# This session is created here only because it is the easiest way to get a runtime, which
# is the easiest way to create the needed docker container
session = Session(
@@ -437,7 +449,7 @@ class DockerNestedConversationManager(ConversationManager):
env_vars['SESSION_API_KEY'] = self._get_session_api_key_for_conversation(sid)
# We need to be able to specify the nested conversation id within the nested runtime
env_vars['ALLOW_SET_CONVERSATION_ID'] = '1'
env_vars['WORKSPACE_BASE'] = f'/workspace'
env_vars['WORKSPACE_BASE'] = '/workspace'
env_vars['SANDBOX_CLOSE_DELAY'] = '0'
# Set up mounted volume for conversation directory within workspace
@@ -483,7 +495,7 @@ class DockerNestedConversationManager(ConversationManager):
await call_sync_from_async(container.start())
return True
return False
except docker.errors.NotFound as e:
except docker.errors.NotFound:
return False

View File

@@ -1,8 +1,9 @@
import asyncio
import time
import types
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Iterable
from typing import Any, Callable, Iterable
import socketio
@@ -61,11 +62,16 @@ class StandaloneConversationManager(ConversationManager):
_cleanup_task: asyncio.Task | None = None
_conversation_store_class: type[ConversationStore] | None = None
async def __aenter__(self):
async def __aenter__(self) -> 'StandaloneConversationManager':
self._cleanup_task = asyncio.create_task(self._cleanup_stale())
return self
async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
if self._cleanup_task:
self._cleanup_task.cancel()
self._cleanup_task = None
@@ -132,7 +138,7 @@ class StandaloneConversationManager(ConversationManager):
agent_loop_info = await self.maybe_start_agent_loop(sid, settings, user_id)
return agent_loop_info
async def detach_from_conversation(self, conversation: ServerConversation):
async def detach_from_conversation(self, conversation: ServerConversation) -> None:
sid = conversation.sid
async with self._conversations_lock:
if sid in self._active_conversations:
@@ -144,7 +150,7 @@ class StandaloneConversationManager(ConversationManager):
self._active_conversations.pop(sid)
self._detached_conversations[sid] = (conversation, time.time())
async def _cleanup_stale(self):
async def _cleanup_stale(self) -> None:
while should_continue():
try:
async with self._conversations_lock:
@@ -324,7 +330,7 @@ class StandaloneConversationManager(ConversationManager):
pass # Already subscribed - take no action
return session
async def send_to_event_stream(self, connection_id: str, data: dict):
async def send_to_event_stream(self, connection_id: str, data: dict) -> None:
# If there is a local session running, send to that
sid = self._local_connection_id_to_session_id.get(connection_id)
if not sid:
@@ -337,7 +343,7 @@ class StandaloneConversationManager(ConversationManager):
raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')
async def disconnect_from_session(self, connection_id: str):
async def disconnect_from_session(self, connection_id: str) -> None:
sid = self._local_connection_id_to_session_id.pop(connection_id, None)
logger.info(
f'disconnect_from_session:{connection_id}:{sid}', extra={'session_id': sid}
@@ -350,12 +356,12 @@ class StandaloneConversationManager(ConversationManager):
)
return
async def close_session(self, sid: str):
async def close_session(self, sid: str) -> None:
session = self._local_agent_loops_by_sid.get(sid)
if session:
await self._close_session(sid)
async def _close_session(self, sid: str):
async def _close_session(self, sid: str) -> None:
logger.info(f'_close_session:{sid}', extra={'session_id': sid})
# Clear up local variables
@@ -402,8 +408,8 @@ class StandaloneConversationManager(ConversationManager):
user_id: str | None,
conversation_id: str,
settings: Settings,
) -> Callable:
def callback(event, *args, **kwargs):
) -> Callable[[Any], None]:
def callback(event: Any) -> None:
call_async_from_sync(
self._update_conversation_for_event,
GENERAL_TIMEOUT,
@@ -420,8 +426,8 @@ class StandaloneConversationManager(ConversationManager):
user_id: str,
conversation_id: str,
settings: Settings,
event=None,
):
event: Any = None,
) -> None:
conversation_store = await self._get_conversation_store(user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)
@@ -473,7 +479,7 @@ class StandaloneConversationManager(ConversationManager):
async def get_agent_loop_info(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
):
) -> list[AgentLoopInfo]:
results = []
for session in self._local_agent_loops_by_sid.values():
if user_id and session.user_id != user_id:
@@ -483,7 +489,7 @@ class StandaloneConversationManager(ConversationManager):
results.append(self._agent_loop_info_from_session(session))
return results
def _agent_loop_info_from_session(self, session: Session):
def _agent_loop_info_from_session(self, session: Session) -> AgentLoopInfo:
return AgentLoopInfo(
conversation_id=session.sid,
url=self._get_conversation_url(session.sid),
@@ -491,7 +497,7 @@ class StandaloneConversationManager(ConversationManager):
event_store=session.agent_session.event_stream,
)
def _get_conversation_url(self, conversation_id: str):
def _get_conversation_url(self, conversation_id: str) -> str:
return f'/api/conversations/{conversation_id}'

View File

@@ -1,5 +1,8 @@
from typing import AsyncGenerator
from fastapi import Depends, Request
from openhands.server.session.conversation import ServerConversation
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth import get_user_id
from openhands.storage.conversation.conversation_store import ConversationStore
@@ -19,7 +22,7 @@ async def get_conversation_store(request: Request) -> ConversationStore | None:
async def get_conversation(
conversation_id: str, user_id: str | None = Depends(get_user_id)
):
) -> AsyncGenerator[ServerConversation | None, None]:
"""Grabs conversation id set by middleware. Adds the conversation_id to the openapi schema."""
conversation = await conversation_manager.attach_to_conversation(
conversation_id, user_id