ALL-4634: implement public conversation sharing feature (#12044)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2025-12-26 10:02:01 -07:00
committed by GitHub
parent cb8c1fa263
commit a829d10213
30 changed files with 2191 additions and 250 deletions

View File

@@ -45,6 +45,8 @@ class AppConversationInfo(BaseModel):
parent_conversation_id: OpenHandsUUID | None = None
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
public: bool | None = None
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
@@ -114,6 +116,12 @@ class AppConversationStartRequest(BaseModel):
parent_conversation_id: OpenHandsUUID | None = None
agent_type: AgentType = Field(default=AgentType.DEFAULT)
public: bool | None = None
class AppConversationUpdateRequest(BaseModel):
public: bool
class AppConversationStartTaskStatus(Enum):
WORKING = 'WORKING'

View File

@@ -40,6 +40,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTask,
AppConversationStartTaskPage,
AppConversationStartTaskSortOrder,
AppConversationUpdateRequest,
SkillResponse,
)
from openhands.app_server.app_conversation.app_conversation_service import (
@@ -222,6 +223,22 @@ async def start_app_conversation(
raise
@router.patch('/{conversation_id}')
async def update_app_conversation(
conversation_id: str,
update_request: AppConversationUpdateRequest,
app_conversation_service: AppConversationService = (
app_conversation_service_dependency
),
) -> AppConversation:
info = await app_conversation_service.update_app_conversation(
UUID(conversation_id), update_request
)
if info is None:
raise HTTPException(404, 'unknown_app_conversation')
return info
@router.post('/stream-start')
async def stream_app_conversation_start(
request: AppConversationStartRequest,

View File

@@ -10,6 +10,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationSortOrder,
AppConversationStartRequest,
AppConversationStartTask,
AppConversationUpdateRequest,
)
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.services.injector import Injector
@@ -98,6 +99,13 @@ class AppConversationService(ABC):
"""Run the setup scripts for the project and yield status updates"""
yield task
@abstractmethod
async def update_app_conversation(
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist."""
@abstractmethod
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
"""Delete a V1 conversation and all its associated data.

View File

@@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
AppConversationStartTask,
AppConversationStartTaskStatus,
AppConversationUpdateRequest,
)
from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
@@ -1049,6 +1050,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
f'Successfully updated agent-server conversation {conversation_id} title to "{new_title}"'
)
async def update_app_conversation(
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist."""
info = await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
if info is None:
return None
for field_name in request.model_fields:
value = getattr(request, field_name)
setattr(info, field_name, value)
info = await self.app_conversation_info_service.save_app_conversation_info(info)
conversations = await self._build_app_conversations([info])
return conversations[0]
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
"""Delete a V1 conversation and all its associated data.

View File

@@ -25,7 +25,17 @@ from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import Column, DateTime, Float, Integer, Select, String, func, select
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
Integer,
Select,
String,
func,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.agent_server.utils import utc_now
@@ -91,6 +101,7 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_version = Column(String, nullable=False, default='V0', index=True)
sandbox_id = Column(String, nullable=True, index=True)
parent_conversation_id = Column(String, nullable=True, index=True)
public = Column(Boolean, nullable=True, index=True)
@dataclass
@@ -350,6 +361,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
if info.parent_conversation_id
else None
),
public=info.public,
)
await self.db_session.merge(stored)
@@ -541,6 +553,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
else None
),
sub_conversation_ids=sub_conversation_ids or [],
public=stored.public,
created_at=created_at,
updated_at=updated_at,
)

View File

@@ -0,0 +1,41 @@
"""add public column to conversation_metadata
Revision ID: 004
Revises: 003
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '004'
down_revision: Union[str, None] = '003'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('public', sa.Boolean(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_public'),
'conversation_metadata',
['public'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_public'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'public')

View File

@@ -22,7 +22,7 @@ event_service_dependency = depends_event_service()
@router.get('/search')
async def search_events(
conversation_id__eq: Annotated[
UUID | None,
str | None,
Query(title='Optional filter by conversation ID'),
] = None,
kind__eq: Annotated[
@@ -55,7 +55,7 @@ async def search_events(
assert limit > 0
assert limit <= 100
return await event_service.search_events(
conversation_id__eq=conversation_id__eq,
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
@@ -68,7 +68,7 @@ async def search_events(
@router.get('/count')
async def count_events(
conversation_id__eq: Annotated[
UUID | None,
str | None,
Query(title='Optional filter by conversation ID'),
] = None,
kind__eq: Annotated[
@@ -91,7 +91,7 @@ async def count_events(
) -> int:
"""Count events matching the given filters."""
return await event_service.count_events(
conversation_id__eq=conversation_id__eq,
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,

View File

@@ -1,32 +1,27 @@
"""Filesystem-based EventService implementation."""
import asyncio
import glob
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.app_conversation.app_conversation_info_service import (
AppConversationInfoService,
)
from openhands.app_server.errors import OpenHandsError
from openhands.app_server.event.event_service import EventService, EventServiceInjector
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.event.filesystem_event_service_base import (
FilesystemEventServiceBase,
)
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
_logger = logging.getLogger(__name__)
@dataclass
class FilesystemEventService(EventService):
class FilesystemEventService(FilesystemEventServiceBase, EventService):
"""Filesystem-based implementation of EventService.
Events are stored in files with the naming format:
@@ -47,25 +42,6 @@ class FilesystemEventService(EventService):
events_path.mkdir(parents=True, exist_ok=True)
return events_path
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
"""Convert timestamp to YYYYMMDDHHMMSS format."""
if isinstance(timestamp, str):
# Parse ISO format timestamp string
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
return dt.strftime('%Y%m%d%H%M%S')
return timestamp.strftime('%Y%m%d%H%M%S')
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
timestamp_str = self._timestamp_to_str(event.timestamp)
kind = event.__class__.__name__
# Handle both UUID objects and string UUIDs
if isinstance(event.id, str):
id_hex = event.id.replace('-', '')
else:
id_hex = event.id.hex
return f'{timestamp_str}_{kind}_{id_hex}'
def _save_event_to_file(self, conversation_id: UUID, event: Event) -> None:
"""Save an event to a file."""
events_path = self._ensure_events_dir(conversation_id)
@@ -77,60 +53,17 @@ class FilesystemEventService(EventService):
data = event.model_dump(mode='json')
f.write(json.dumps(data, indent=2))
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
events = []
for file_path in file_paths:
event = self._load_event_from_file(file_path)
if event is not None:
events.append(event)
return events
def _load_event_from_file(self, filepath: Path) -> Event | None:
"""Load an event from a file."""
try:
json_data = filepath.read_text()
return Event.model_validate_json(json_data)
except Exception:
return None
def _get_event_files_by_pattern(
self, pattern: str, conversation_id: UUID | None = None
) -> list[Path]:
"""Get event files matching a glob pattern, sorted by timestamp."""
if conversation_id:
search_path = self.events_dir / str(conversation_id) / pattern
else:
search_path = self.events_dir / '*' / pattern
files = glob.glob(str(search_path))
return sorted([Path(f) for f in files])
def _parse_filename(self, filename: str) -> dict[str, str] | None:
"""Parse filename to extract timestamp, kind, and event_id."""
try:
parts = filename.split('_')
if len(parts) >= 3:
timestamp_str = parts[0]
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
event_id = parts[-1]
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
except Exception:
pass
return None
def _get_conversation_id(self, file: Path) -> UUID | None:
try:
return UUID(file.parent.name)
except Exception:
return None
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
result = set()
for file in files:
conversation_id = self._get_conversation_id(file)
if conversation_id:
result.add(conversation_id)
return result
async def save_event(self, conversation_id: UUID, event: Event):
"""Save an event. Internal method intended not be part of the REST api."""
conversation = (
await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
)
if not conversation:
# This is either an illegal state or somebody is trying to hack
raise OpenHandsError('No such conversation: {conversaiont_id}')
self._save_event_to_file(conversation_id, event)
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
conversation_ids = list(self._get_conversation_ids(files))
@@ -150,161 +83,6 @@ class FilesystemEventService(EventService):
]
return result
def _filter_files_by_criteria(
self,
files: list[Path],
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
) -> list[Path]:
"""Filter files based on search criteria."""
filtered_files = []
for file_path in files:
# Check conversation_id filter
if conversation_id__eq:
if str(conversation_id__eq) not in str(file_path):
continue
# Parse filename for additional filtering
filename_info = self._parse_filename(file_path.name)
if not filename_info:
continue
# Check kind filter
if kind__eq and filename_info['kind'] != kind__eq:
continue
# Check timestamp filters
if timestamp__gte or timestamp__lt:
try:
file_timestamp = datetime.strptime(
filename_info['timestamp'], '%Y%m%d%H%M%S'
)
if timestamp__gte and file_timestamp < timestamp__gte:
continue
if timestamp__lt and file_timestamp >= timestamp__lt:
continue
except ValueError:
continue
filtered_files.append(file_path)
return filtered_files
async def get_event(self, event_id: str) -> Event | None:
"""Get the event with the given id, or None if not found."""
# Convert event_id to hex format (remove dashes) for filename matching
if isinstance(event_id, str) and '-' in event_id:
id_hex = event_id.replace('-', '')
else:
id_hex = event_id
# Use glob pattern to find files ending with the event_id
pattern = f'*_{id_hex}'
files = self._get_event_files_by_pattern(pattern)
if not files:
return None
# If there is no access to the conversation do not return the event
file = files[0]
conversation_id = self._get_conversation_id(file)
if not conversation_id:
return None
conversation = (
await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
)
if not conversation:
return None
# Load and return the first matching event
return self._load_event_from_file(file)
async def search_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search for events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
files.sort(
key=lambda f: f.name,
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
)
# Handle pagination
start_index = 0
if page_id:
for i, file_path in enumerate(files):
if file_path.name == page_id:
start_index = i + 1
break
# Collect items for this page
page_files = files[start_index : start_index + limit]
next_page_id = None
if start_index + limit < len(files):
next_page_id = files[start_index + limit].name
# Load all events from files in a background thread.
loop = asyncio.get_running_loop()
page_events = await loop.run_in_executor(
None, self._load_events_from_files, page_files
)
return EventPage(items=page_events, next_page_id=next_page_id)
async def count_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
return len(files)
async def save_event(self, conversation_id: UUID, event: Event):
"""Save an event. Internal method intended not be part of the REST api."""
conversation = (
await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
)
if not conversation:
# This is either an illegal state or somebody is trying to hack
raise OpenHandsError('No such conversation: {conversaiont_id}')
self._save_event_to_file(conversation_id, event)
class FilesystemEventServiceInjector(EventServiceInjector):
async def inject(

View File

@@ -0,0 +1,224 @@
import asyncio
import glob
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from uuid import UUID
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
class FilesystemEventServiceBase:
events_dir: Path
async def get_event(self, event_id: str) -> Event | None:
"""Get the event with the given id, or None if not found."""
# Convert event_id to hex format (remove dashes) for filename matching
if isinstance(event_id, str) and '-' in event_id:
id_hex = event_id.replace('-', '')
else:
id_hex = event_id
# Use glob pattern to find files ending with the event_id
pattern = f'*_{id_hex}'
files = self._get_event_files_by_pattern(pattern)
files = await self._filter_files_by_conversation(files)
if not files:
return None
# Load and return the first matching event
return self._load_event_from_file(files[0])
async def search_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search for events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
files.sort(
key=lambda f: f.name,
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
)
# Handle pagination
start_index = 0
if page_id:
for i, file_path in enumerate(files):
if file_path.name == page_id:
start_index = i + 1
break
# Collect items for this page
page_files = files[start_index : start_index + limit]
next_page_id = None
if start_index + limit < len(files):
next_page_id = files[start_index + limit].name
# Load all events from files in a background thread.
loop = asyncio.get_running_loop()
page_events = await loop.run_in_executor(
None, self._load_events_from_files, page_files
)
return EventPage(items=page_events, next_page_id=next_page_id)
async def count_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
return len(files)
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
timestamp_str = self._timestamp_to_str(event.timestamp)
kind = event.__class__.__name__
# Handle both UUID objects and string UUIDs
if isinstance(event.id, str):
id_hex = event.id.replace('-', '')
else:
id_hex = event.id.hex
return f'{timestamp_str}_{kind}_{id_hex}'
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
"""Convert timestamp to YYYYMMDDHHMMSS format."""
if isinstance(timestamp, str):
# Parse ISO format timestamp string
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
return dt.strftime('%Y%m%d%H%M%S')
return timestamp.strftime('%Y%m%d%H%M%S')
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
events = []
for file_path in file_paths:
event = self._load_event_from_file(file_path)
if event is not None:
events.append(event)
return events
def _load_event_from_file(self, filepath: Path) -> Event | None:
"""Load an event from a file."""
try:
json_data = filepath.read_text()
return Event.model_validate_json(json_data)
except Exception:
return None
def _get_event_files_by_pattern(
self, pattern: str, conversation_id: UUID | None = None
) -> list[Path]:
"""Get event files matching a glob pattern, sorted by timestamp."""
if conversation_id:
search_path = self.events_dir / str(conversation_id) / pattern
else:
search_path = self.events_dir / '*' / pattern
files = glob.glob(str(search_path))
return sorted([Path(f) for f in files])
def _parse_filename(self, filename: str) -> dict[str, str] | None:
"""Parse filename to extract timestamp, kind, and event_id."""
try:
parts = filename.split('_')
if len(parts) >= 3:
timestamp_str = parts[0]
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
event_id = parts[-1]
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
except Exception:
pass
return None
def _get_conversation_id(self, file: Path) -> UUID | None:
try:
return UUID(file.parent.name)
except Exception:
return None
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
result = set()
for file in files:
conversation_id = self._get_conversation_id(file)
if conversation_id:
result.add(conversation_id)
return result
@abstractmethod
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
"""Filter files by conversation."""
def _filter_files_by_criteria(
self,
files: list[Path],
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
) -> list[Path]:
"""Filter files based on search criteria."""
filtered_files = []
for file_path in files:
# Check conversation_id filter
if conversation_id__eq:
if str(conversation_id__eq) not in str(file_path):
continue
# Parse filename for additional filtering
filename_info = self._parse_filename(file_path.name)
if not filename_info:
continue
# Check kind filter
if kind__eq and filename_info['kind'] != kind__eq:
continue
# Check timestamp filters
if timestamp__gte or timestamp__lt:
try:
file_timestamp = datetime.strptime(
filename_info['timestamp'], '%Y%m%d%H%M%S'
)
if timestamp__gte and file_timestamp < timestamp__gte:
continue
if timestamp__lt and file_timestamp >= timestamp__lt:
continue
except ValueError:
continue
filtered_files.append(file_path)
return filtered_files

View File

@@ -29,3 +29,4 @@ class ConversationInfo:
pr_number: list[int] = field(default_factory=list)
conversation_version: str = 'V0'
sub_conversation_ids: list[str] = field(default_factory=list)
public: bool | None = None

View File

@@ -1501,4 +1501,5 @@ def _to_conversation_info(app_conversation: AppConversation) -> ConversationInfo
sub_conversation_ids=[
sub_id.hex for sub_id in app_conversation.sub_conversation_ids
],
public=app_conversation.public,
)

View File

@@ -39,3 +39,4 @@ class ConversationMetadata:
# V1 compatibility
sandbox_id: str | None = None
conversation_version: str | None = None
public: bool | None = None