ALL-4634: implement public conversation sharing feature (#12044)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2025-12-26 10:02:01 -07:00
committed by GitHub
parent cb8c1fa263
commit a829d10213
30 changed files with 2191 additions and 250 deletions

View File

@@ -0,0 +1,41 @@
"""add public column to conversation_metadata
Revision ID: 085
Revises: 084
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '085'
down_revision: Union[str, None] = '084'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('public', sa.Boolean(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_public'),
'conversation_metadata',
['public'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_public'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'public')

View File

@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-ai"
version = "0.0.0-post.5687+7853b41ad"
version = "0.0.0-post.5750+f19fb1043"
description = "OpenHands: Code Less, Make More"
optional = false
python-versions = "^3.12,<3.14"

View File

@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
from server.sharing.shared_conversation_router import ( # noqa: E402
router as shared_conversation_router,
)
from server.sharing.shared_event_router import ( # noqa: E402
router as shared_event_router,
)
from openhands.server.app import app as base_app # noqa: E402
from openhands.server.listen_socket import sio # noqa: E402
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
base_app.include_router(
billing_router
) # Add routes for credit management and Stripe payment integration
base_app.include_router(shared_conversation_router)
base_app.include_router(shared_event_router)
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
if GITHUB_APP_CLIENT_ID:
@@ -99,6 +107,7 @@ base_app.include_router(
event_webhook_router
) # Add routes for Events in nested runtimes
base_app.add_middleware(
CORSMiddleware,
allow_origins=PERMITTED_CORS_ORIGINS,

View File

@@ -0,0 +1,20 @@
# Sharing Package
This package contains functionality for sharing conversations.
## Components
- **shared.py**: Data models for shared conversations
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
- **shared_event_service.py**: Service interface for accessing shared events
- **shared_event_service_impl.py**: Implementation of the shared event service
- **shared_conversation_router.py**: REST API endpoints for shared conversations
- **shared_event_router.py**: REST API endpoints for shared events
## Features
- Read-only access to shared conversations
- Event access for shared conversations
- Search and filtering capabilities
- Pagination support

View File

@@ -0,0 +1,142 @@
"""Implementation of SharedEventService.
This implementation provides read-only access to events from shared conversations:
- Validates that the conversation is shared before returning events
- Uses existing EventService for actual event retrieval
- Uses SharedConversationInfoService for shared conversation validation
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_event_service import (
SharedEventService,
SharedEventServiceInjector,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
logger = logging.getLogger(__name__)
@dataclass
class SharedEventServiceImpl(SharedEventService):
"""Implementation of SharedEventService that validates shared access."""
shared_conversation_info_service: SharedConversationInfoService
event_service: EventService
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return None
# If conversation is shared, get the event
return await self.event_service.get_event(event_id)
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
# Return empty page if conversation is not shared
return EventPage(items=[], next_page_id=None)
# If conversation is shared, search events for this conversation
return await self.event_service.search_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return 0
# If conversation is shared, count events for this conversation
return await self.event_service.count_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
class SharedEventServiceImplInjector(SharedEventServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedEventService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import (
get_db_session,
get_event_service,
)
async with (
get_db_session(state, request) as db_session,
get_event_service(state, request) as event_service,
):
shared_conversation_info_service = SQLSharedConversationInfoService(
db_session=db_session
)
service = SharedEventServiceImpl(
shared_conversation_info_service=shared_conversation_info_service,
event_service=event_service,
)
yield service

View File

@@ -0,0 +1,66 @@
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from openhands.app_server.services.injector import Injector
from openhands.sdk.utils.models import DiscriminatedUnionMixin
class SharedConversationInfoService(ABC):
"""Service for accessing shared conversation info without user restrictions."""
@abstractmethod
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
@abstractmethod
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations."""
@abstractmethod
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single shared conversation info, returning None if missing or not shared."""
async def batch_get_shared_conversation_info(
self, conversation_ids: list[UUID]
) -> list[SharedConversation | None]:
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
return await asyncio.gather(
*[
self.get_shared_conversation_info(conversation_id)
for conversation_id in conversation_ids
]
)
class SharedConversationInfoServiceInjector(
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
):
pass

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from enum import Enum
# Simplified imports to avoid dependency chain issues
# from openhands.integrations.service_types import ProviderType
# from openhands.sdk.llm import MetricsSnapshot
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# For now, use Any to avoid import issues
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from openhands.agent_server.utils import OpenHandsUUID, utc_now
ProviderType = Any
MetricsSnapshot = Any
ConversationTrigger = Any
class SharedConversation(BaseModel):
"""Shared conversation info model with all fields from AppConversationInfo."""
id: OpenHandsUUID = Field(default_factory=uuid4)
created_by_user_id: str | None
sandbox_id: str
selected_repository: str | None = None
selected_branch: str | None = None
git_provider: ProviderType | None = None
title: str | None = None
pr_number: list[int] = Field(default_factory=list)
llm_model: str | None = None
metrics: MetricsSnapshot | None = None
parent_conversation_id: OpenHandsUUID | None = None
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
class SharedConversationSortOrder(Enum):
CREATED_AT = 'CREATED_AT'
CREATED_AT_DESC = 'CREATED_AT_DESC'
UPDATED_AT = 'UPDATED_AT'
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
TITLE = 'TITLE'
TITLE_DESC = 'TITLE_DESC'
class SharedConversationPage(BaseModel):
items: list[SharedConversation]
next_page_id: str | None = None

View File

@@ -0,0 +1,135 @@
"""Shared Conversation router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoServiceInjector,
)
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
shared_conversation_info_service_dependency = Depends(
SQLSharedConversationInfoServiceInjector().depends
)
# Read methods
@router.get('/search')
async def search_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
sort_order: Annotated[
SharedConversationSortOrder,
Query(title='Sort order for results'),
] = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(
title='The max number of results in the page',
gt=0,
lte=100,
),
] = 100,
include_sub_conversations: Annotated[
bool,
Query(
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
),
] = False,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
"""Search / List shared conversations."""
assert limit > 0
assert limit <= 100
return await shared_conversation_service.search_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
include_sub_conversations=include_sub_conversations,
)
@router.get('/count')
async def count_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> int:
"""Count shared conversations matching the given filters."""
return await shared_conversation_service.count_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
@router.get('')
async def batch_get_shared_conversations(
ids: Annotated[list[str], Query()],
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> list[SharedConversation | None]:
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
assert len(ids) <= 100
uuids = [UUID(id_) for id_ in ids]
shared_conversation_info = (
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
)
return shared_conversation_info

View File

@@ -0,0 +1,126 @@
"""Shared Event router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImplInjector,
)
from server.sharing.shared_event_service import SharedEventService
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
# Read methods
@router.get('/search')
async def search_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to search events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
"""Search / List events for a shared conversation."""
assert limit > 0
assert limit <= 100
return await shared_event_service.search_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
@router.get('/count')
async def count_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to count events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> int:
"""Count events for a shared conversation matching the given filters."""
return await shared_event_service.count_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
@router.get('')
async def batch_get_shared_events(
conversation_id: Annotated[
UUID,
Query(title='Conversation ID to get events for'),
],
id: Annotated[list[str], Query()],
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> list[Event | None]:
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
assert len(id) <= 100
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
return events
@router.get('/{conversation_id}/{event_id}')
async def get_shared_event(
conversation_id: UUID,
event_id: str,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
"""Get a single event from a shared conversation by conversation_id and event_id."""
return await shared_event_service.get_shared_event(conversation_id, event_id)

View File

@@ -0,0 +1,64 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import Injector
from openhands.sdk import Event
from openhands.sdk.utils.models import DiscriminatedUnionMixin
_logger = logging.getLogger(__name__)
class SharedEventService(ABC):
"""Event Service for getting events from shared conversations only."""
@abstractmethod
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
@abstractmethod
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
@abstractmethod
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
async def batch_get_shared_events(
self, conversation_id: UUID, event_ids: list[str]
) -> list[Event | None]:
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
return await asyncio.gather(
*[
self.get_shared_event(conversation_id, event_id)
for event_id in event_ids
]
)
class SharedEventServiceInjector(
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
):
pass

View File

@@ -0,0 +1,282 @@
"""SQL implementation of SharedConversationInfoService.
This implementation provides read-only access to shared conversations:
- Direct database access without user permission checks
- Filters only conversations marked as shared (currently public)
- Full async/await support using SQL async db_sessions
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
SharedConversationInfoServiceInjector,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.services.injector import InjectorState
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
logger = logging.getLogger(__name__)
@dataclass
class SQLSharedConversationInfoService(SharedConversationInfoService):
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
db_session: AsyncSession
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == SharedConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == SharedConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.scalars().all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [self._to_shared_conversation(row) for row in rows]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return SharedConversationPage(items=items, next_page_id=next_page_id)
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations matching the given filters."""
from sqlalchemy import func
query = select(func.count(StoredConversationMetadata.conversation_id))
# Only include shared conversations
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
return result.scalar() or 0
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
stored = result.scalar_one_or_none()
if stored is None:
return None
return self._to_shared_conversation(stored)
def _public_select(self):
"""Create a select query that only returns public conversations."""
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
# Only include conversations marked as public
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _apply_filters(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply common filters to a query."""
if title__contains is not None:
query = query.where(
StoredConversationMetadata.title.contains(title__contains)
)
if created_at__gte is not None:
query = query.where(
StoredConversationMetadata.created_at >= created_at__gte
)
if created_at__lt is not None:
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
query = query.where(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
query = query.where(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
return query
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation."""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
# Rebuild token usage
token_usage = TokenUsage(
prompt_tokens=stored.prompt_tokens,
completion_tokens=stored.completion_tokens,
cache_read_tokens=stored.cache_read_tokens,
cache_write_tokens=stored.cache_write_tokens,
context_window=stored.context_window,
per_turn_token=stored.per_turn_token,
)
# Rebuild metrics object
metrics = MetricsSnapshot(
accumulated_cost=stored.accumulated_cost,
max_budget_per_task=stored.max_budget_per_task,
accumulated_token_usage=token_usage,
)
# Get timestamps
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
git_provider=(
ProviderType(stored.git_provider) if stored.git_provider else None
),
title=stored.title,
pr_number=stored.pr_number,
llm_model=stored.llm_model,
metrics=metrics,
parent_conversation_id=(
UUID(stored.parent_conversation_id)
if stored.parent_conversation_id
else None
),
sub_conversation_ids=sub_conversation_ids or [],
created_at=created_at,
updated_at=updated_at,
)
def _fix_timezone(self, value: datetime) -> datetime:
"""Sqlite does not store timezones - and since we can't update the existing models
we assume UTC if the timezone is missing."""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
return value
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedConversationInfoService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import get_db_session
async with get_db_session(state, request) as db_session:
service = SQLSharedConversationInfoService(db_session=db_session)
yield service

View File

@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
kwargs.pop('context_window', None)
kwargs.pop('per_turn_token', None)
kwargs.pop('parent_conversation_id', None)
kwargs.pop('public')
return ConversationMetadata(**kwargs)

View File

@@ -0,0 +1 @@
"""Tests for sharing package."""

View File

@@ -0,0 +1,91 @@
"""Tests for public conversation models."""
from datetime import datetime
from uuid import uuid4
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
def test_public_conversation_creation():
"""Test that SharedConversation can be created with all required fields."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
assert conversation.id == conversation_id
assert conversation.title == 'Test Conversation'
assert conversation.created_by_user_id == 'test_user'
assert conversation.sandbox_id == 'test_sandbox'
def test_public_conversation_page_creation():
"""Test that SharedConversationPage can be created."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
page = SharedConversationPage(
items=[conversation],
next_page_id='next_page',
)
assert len(page.items) == 1
assert page.items[0].id == conversation_id
assert page.next_page_id == 'next_page'
def test_public_conversation_sort_order_enum():
"""Test that SharedConversationSortOrder enum has expected values."""
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'TITLE')
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
def test_public_conversation_optional_fields():
"""Test that SharedConversation works with optional fields."""
conversation_id = uuid4()
parent_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository='owner/repo',
parent_conversation_id=parent_id,
llm_model='gpt-4',
)
assert conversation.selected_repository == 'owner/repo'
assert conversation.parent_conversation_id == parent_id
assert conversation.llm_model == 'gpt-4'

View File

@@ -0,0 +1,430 @@
"""Tests for SharedConversationInfoService."""
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import uuid4
import pytest
from server.sharing.shared_conversation_models import (
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
async def shared_conversation_info_service(async_session):
"""Create a SharedConversationInfoService for testing."""
return SQLSharedConversationInfoService(db_session=async_session)
@pytest.fixture
async def app_conversation_service(async_session):
"""Create an AppConversationInfoService for creating test data."""
return SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
def sample_conversation_info():
"""Create a sample conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
selected_repository='test/repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Test Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[123],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=1.5,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=150,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=True, # Make it public for testing
)
@pytest.fixture
def sample_private_conversation_info():
"""Create a sample private conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_private',
selected_repository='test/private_repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Private Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[124],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=2.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=300,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=False, # Make it private
)
class TestSharedConversationInfoService:
"""Test cases for SharedConversationInfoService."""
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_public_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test that get_shared_conversation_info returns a public conversation."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_conversation_info.id
)
assert result is not None
assert result.id == sample_conversation_info.id
assert result.title == sample_conversation_info.title
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_private_conversation_info,
):
"""Test that get_shared_conversation_info returns None for private conversations."""
# Create a private conversation
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Try to retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_private_conversation_info.id
)
assert result is None
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
self, shared_conversation_info_service
):
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
nonexistent_id = uuid4()
result = await shared_conversation_info_service.get_shared_conversation_info(
nonexistent_id
)
assert result is None
@pytest.mark.asyncio
async def test_search_shared_conversation_info_returns_only_public_conversations(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test that search only returns public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Search for all conversations
result = (
await shared_conversation_info_service.search_shared_conversation_info()
)
# Should only return the public conversation
assert len(result.items) == 1
assert result.items[0].id == sample_conversation_info.id
assert result.items[0].title == sample_conversation_info.title
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_title_filter(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test searching with title filter."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Search with matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='Test'
)
assert len(result.items) == 1
# Search with non-matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='NonExistent'
)
assert len(result.items) == 0
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_sort_order(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test searching with different sort orders."""
# Create multiple public conversations with different titles and timestamps
conv1 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_1',
title='A First Conversation',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv2 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_2',
title='B Second Conversation',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_conversation_service.save_app_conversation_info(conv1)
await app_conversation_service.save_app_conversation_info(conv2)
# Test sort by title ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE
)
assert len(result.items) == 2
assert result.items[0].title == 'A First Conversation'
assert result.items[1].title == 'B Second Conversation'
# Test sort by title descending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE_DESC
)
assert len(result.items) == 2
assert result.items[0].title == 'B Second Conversation'
assert result.items[1].title == 'A First Conversation'
# Test sort by created_at ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.items[0].id == conv1.id
assert result.items[1].id == conv2.id
# Test sort by created_at descending (default)
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
)
assert len(result.items) == 2
assert result.items[0].id == conv2.id
assert result.items[1].id == conv1.id
@pytest.mark.asyncio
async def test_count_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test counting public conversations."""
# Initially should be 0
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 0
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
# Create a private conversation - count should remain 1
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
@pytest.mark.asyncio
async def test_batch_get_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test batch getting public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Batch get both conversations
result = (
await shared_conversation_info_service.batch_get_shared_conversation_info(
[sample_conversation_info.id, sample_private_conversation_info.id]
)
)
# Should return the public one and None for the private one
assert len(result) == 2
assert result[0] is not None
assert result[0].id == sample_conversation_info.id
assert result[1] is None
@pytest.mark.asyncio
async def test_search_with_pagination(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test search with pagination."""
# Create multiple public conversations
conversations = []
for i in range(5):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id=f'test_sandbox_{i}',
title=f'Conversation {i}',
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conversations.append(conv)
await app_conversation_service.save_app_conversation_info(conv)
# Get first page with limit 2
result = await shared_conversation_info_service.search_shared_conversation_info(
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.next_page_id is not None
# Get next page
result2 = (
await shared_conversation_info_service.search_shared_conversation_info(
limit=2,
page_id=result.next_page_id,
sort_order=SharedConversationSortOrder.CREATED_AT,
)
)
assert len(result2.items) == 2
assert result2.next_page_id is not None
# Verify no overlap between pages
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)

View File

@@ -0,0 +1,365 @@
"""Tests for SharedEventService."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImpl,
)
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import SharedConversation
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
@pytest.fixture
def mock_shared_conversation_info_service():
"""Create a mock SharedConversationInfoService."""
return AsyncMock(spec=SharedConversationInfoService)
@pytest.fixture
def mock_event_service():
"""Create a mock EventService."""
return AsyncMock(spec=EventService)
@pytest.fixture
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
"""Create a SharedEventService for testing."""
return SharedEventServiceImpl(
shared_conversation_info_service=mock_shared_conversation_info_service,
event_service=mock_event_service,
)
@pytest.fixture
def sample_public_conversation():
"""Create a sample public conversation."""
return SharedConversation(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Public Conversation',
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
@pytest.fixture
def sample_event():
"""Create a sample event."""
# For testing purposes, we'll just use a mock that the EventPage can accept
# The actual event creation is complex and not the focus of these tests
return None
class TestSharedEventService:
"""Test cases for SharedEventService."""
async def test_get_shared_event_returns_event_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that get_shared_event returns an event for a public conversation."""
conversation_id = sample_public_conversation.id
event_id = 'test_event_id'
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return an event
mock_event_service.get_event.return_value = sample_event
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result == sample_event
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.get_event.assert_called_once_with(event_id)
async def test_get_shared_event_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that get_shared_event returns None for a private conversation."""
conversation_id = uuid4()
event_id = 'test_event_id'
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that search_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id=None)
mock_event_service.search_events.return_value = mock_event_page
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
limit=10,
)
# Verify the result
assert result == mock_event_page
assert len(result.items) == 0 # Empty list as we mocked
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
page_id=None,
limit=10,
)
async def test_search_shared_events_returns_empty_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that search_shared_events returns empty page for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
limit=10,
)
# Verify the result
assert isinstance(result, EventPage)
assert len(result.items) == 0
assert result.next_page_id is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.search_events.assert_not_called()
async def test_count_shared_events_returns_count_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test that count_shared_events returns count for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return a count
mock_event_service.count_events.return_value = 5
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
)
# Verify the result
assert result == 5
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.count_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
)
async def test_count_shared_events_returns_zero_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that count_shared_events returns 0 for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
)
# Verify the result
assert result == 0
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.count_events.assert_not_called()
async def test_batch_get_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that batch_get_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
event_ids = ['event1', 'event2']
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_service.get_event.side_effect = [sample_event, None]
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] == sample_event
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Verify that get_event was called for each event
assert mock_event_service.get_event.call_count == 2
async def test_batch_get_shared_events_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that batch_get_shared_events returns None for a private conversation."""
conversation_id = uuid4()
event_ids = ['event1', 'event2']
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] is None
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_with_all_parameters(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test search_shared_events with all parameters."""
conversation_id = sample_public_conversation.id
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id='next_page')
mock_event_service.search_events.return_value = mock_event_page
# Call the method with all parameters
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)
# Verify the result
assert result == mock_event_page
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)

View File

@@ -45,6 +45,8 @@ class AppConversationInfo(BaseModel):
parent_conversation_id: OpenHandsUUID | None = None
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
public: bool | None = None
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
@@ -114,6 +116,12 @@ class AppConversationStartRequest(BaseModel):
parent_conversation_id: OpenHandsUUID | None = None
agent_type: AgentType = Field(default=AgentType.DEFAULT)
public: bool | None = None
class AppConversationUpdateRequest(BaseModel):
public: bool
class AppConversationStartTaskStatus(Enum):
WORKING = 'WORKING'

View File

@@ -40,6 +40,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTask,
AppConversationStartTaskPage,
AppConversationStartTaskSortOrder,
AppConversationUpdateRequest,
SkillResponse,
)
from openhands.app_server.app_conversation.app_conversation_service import (
@@ -222,6 +223,22 @@ async def start_app_conversation(
raise
@router.patch('/{conversation_id}')
async def update_app_conversation(
conversation_id: str,
update_request: AppConversationUpdateRequest,
app_conversation_service: AppConversationService = (
app_conversation_service_dependency
),
) -> AppConversation:
info = await app_conversation_service.update_app_conversation(
UUID(conversation_id), update_request
)
if info is None:
raise HTTPException(404, 'unknown_app_conversation')
return info
@router.post('/stream-start')
async def stream_app_conversation_start(
request: AppConversationStartRequest,

View File

@@ -10,6 +10,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationSortOrder,
AppConversationStartRequest,
AppConversationStartTask,
AppConversationUpdateRequest,
)
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.services.injector import Injector
@@ -98,6 +99,13 @@ class AppConversationService(ABC):
"""Run the setup scripts for the project and yield status updates"""
yield task
@abstractmethod
async def update_app_conversation(
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist."""
@abstractmethod
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
"""Delete a V1 conversation and all its associated data.

View File

@@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
AppConversationStartTask,
AppConversationStartTaskStatus,
AppConversationUpdateRequest,
)
from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
@@ -1049,6 +1050,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
f'Successfully updated agent-server conversation {conversation_id} title to "{new_title}"'
)
async def update_app_conversation(
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist."""
info = await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
if info is None:
return None
for field_name in request.model_fields:
value = getattr(request, field_name)
setattr(info, field_name, value)
info = await self.app_conversation_info_service.save_app_conversation_info(info)
conversations = await self._build_app_conversations([info])
return conversations[0]
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
"""Delete a V1 conversation and all its associated data.

View File

@@ -25,7 +25,17 @@ from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import Column, DateTime, Float, Integer, Select, String, func, select
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
Integer,
Select,
String,
func,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.agent_server.utils import utc_now
@@ -91,6 +101,7 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_version = Column(String, nullable=False, default='V0', index=True)
sandbox_id = Column(String, nullable=True, index=True)
parent_conversation_id = Column(String, nullable=True, index=True)
public = Column(Boolean, nullable=True, index=True)
@dataclass
@@ -350,6 +361,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
if info.parent_conversation_id
else None
),
public=info.public,
)
await self.db_session.merge(stored)
@@ -541,6 +553,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
else None
),
sub_conversation_ids=sub_conversation_ids or [],
public=stored.public,
created_at=created_at,
updated_at=updated_at,
)

View File

@@ -0,0 +1,41 @@
"""add public column to conversation_metadata
Revision ID: 004
Revises: 003
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '004'
down_revision: Union[str, None] = '003'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('public', sa.Boolean(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_public'),
'conversation_metadata',
['public'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_public'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'public')

View File

@@ -22,7 +22,7 @@ event_service_dependency = depends_event_service()
@router.get('/search')
async def search_events(
conversation_id__eq: Annotated[
UUID | None,
str | None,
Query(title='Optional filter by conversation ID'),
] = None,
kind__eq: Annotated[
@@ -55,7 +55,7 @@ async def search_events(
assert limit > 0
assert limit <= 100
return await event_service.search_events(
conversation_id__eq=conversation_id__eq,
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
@@ -68,7 +68,7 @@ async def search_events(
@router.get('/count')
async def count_events(
conversation_id__eq: Annotated[
UUID | None,
str | None,
Query(title='Optional filter by conversation ID'),
] = None,
kind__eq: Annotated[
@@ -91,7 +91,7 @@ async def count_events(
) -> int:
"""Count events matching the given filters."""
return await event_service.count_events(
conversation_id__eq=conversation_id__eq,
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,

View File

@@ -1,32 +1,27 @@
"""Filesystem-based EventService implementation."""
import asyncio
import glob
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.app_conversation.app_conversation_info_service import (
AppConversationInfoService,
)
from openhands.app_server.errors import OpenHandsError
from openhands.app_server.event.event_service import EventService, EventServiceInjector
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.event.filesystem_event_service_base import (
FilesystemEventServiceBase,
)
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
_logger = logging.getLogger(__name__)
@dataclass
class FilesystemEventService(EventService):
class FilesystemEventService(FilesystemEventServiceBase, EventService):
"""Filesystem-based implementation of EventService.
Events are stored in files with the naming format:
@@ -47,25 +42,6 @@ class FilesystemEventService(EventService):
events_path.mkdir(parents=True, exist_ok=True)
return events_path
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
"""Convert timestamp to YYYYMMDDHHMMSS format."""
if isinstance(timestamp, str):
# Parse ISO format timestamp string
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
return dt.strftime('%Y%m%d%H%M%S')
return timestamp.strftime('%Y%m%d%H%M%S')
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
timestamp_str = self._timestamp_to_str(event.timestamp)
kind = event.__class__.__name__
# Handle both UUID objects and string UUIDs
if isinstance(event.id, str):
id_hex = event.id.replace('-', '')
else:
id_hex = event.id.hex
return f'{timestamp_str}_{kind}_{id_hex}'
def _save_event_to_file(self, conversation_id: UUID, event: Event) -> None:
"""Save an event to a file."""
events_path = self._ensure_events_dir(conversation_id)
@@ -77,60 +53,17 @@ class FilesystemEventService(EventService):
data = event.model_dump(mode='json')
f.write(json.dumps(data, indent=2))
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
events = []
for file_path in file_paths:
event = self._load_event_from_file(file_path)
if event is not None:
events.append(event)
return events
def _load_event_from_file(self, filepath: Path) -> Event | None:
"""Load an event from a file."""
try:
json_data = filepath.read_text()
return Event.model_validate_json(json_data)
except Exception:
return None
def _get_event_files_by_pattern(
self, pattern: str, conversation_id: UUID | None = None
) -> list[Path]:
"""Get event files matching a glob pattern, sorted by timestamp."""
if conversation_id:
search_path = self.events_dir / str(conversation_id) / pattern
else:
search_path = self.events_dir / '*' / pattern
files = glob.glob(str(search_path))
return sorted([Path(f) for f in files])
def _parse_filename(self, filename: str) -> dict[str, str] | None:
"""Parse filename to extract timestamp, kind, and event_id."""
try:
parts = filename.split('_')
if len(parts) >= 3:
timestamp_str = parts[0]
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
event_id = parts[-1]
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
except Exception:
pass
return None
def _get_conversation_id(self, file: Path) -> UUID | None:
try:
return UUID(file.parent.name)
except Exception:
return None
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
result = set()
for file in files:
conversation_id = self._get_conversation_id(file)
if conversation_id:
result.add(conversation_id)
return result
async def save_event(self, conversation_id: UUID, event: Event):
"""Save an event. Internal method intended not be part of the REST api."""
conversation = (
await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
)
if not conversation:
# This is either an illegal state or somebody is trying to hack
raise OpenHandsError('No such conversation: {conversaiont_id}')
self._save_event_to_file(conversation_id, event)
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
conversation_ids = list(self._get_conversation_ids(files))
@@ -150,161 +83,6 @@ class FilesystemEventService(EventService):
]
return result
def _filter_files_by_criteria(
self,
files: list[Path],
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
) -> list[Path]:
"""Filter files based on search criteria."""
filtered_files = []
for file_path in files:
# Check conversation_id filter
if conversation_id__eq:
if str(conversation_id__eq) not in str(file_path):
continue
# Parse filename for additional filtering
filename_info = self._parse_filename(file_path.name)
if not filename_info:
continue
# Check kind filter
if kind__eq and filename_info['kind'] != kind__eq:
continue
# Check timestamp filters
if timestamp__gte or timestamp__lt:
try:
file_timestamp = datetime.strptime(
filename_info['timestamp'], '%Y%m%d%H%M%S'
)
if timestamp__gte and file_timestamp < timestamp__gte:
continue
if timestamp__lt and file_timestamp >= timestamp__lt:
continue
except ValueError:
continue
filtered_files.append(file_path)
return filtered_files
async def get_event(self, event_id: str) -> Event | None:
"""Get the event with the given id, or None if not found."""
# Convert event_id to hex format (remove dashes) for filename matching
if isinstance(event_id, str) and '-' in event_id:
id_hex = event_id.replace('-', '')
else:
id_hex = event_id
# Use glob pattern to find files ending with the event_id
pattern = f'*_{id_hex}'
files = self._get_event_files_by_pattern(pattern)
if not files:
return None
# If there is no access to the conversation do not return the event
file = files[0]
conversation_id = self._get_conversation_id(file)
if not conversation_id:
return None
conversation = (
await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
)
if not conversation:
return None
# Load and return the first matching event
return self._load_event_from_file(file)
async def search_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search for events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
files.sort(
key=lambda f: f.name,
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
)
# Handle pagination
start_index = 0
if page_id:
for i, file_path in enumerate(files):
if file_path.name == page_id:
start_index = i + 1
break
# Collect items for this page
page_files = files[start_index : start_index + limit]
next_page_id = None
if start_index + limit < len(files):
next_page_id = files[start_index + limit].name
# Load all events from files in a background thread.
loop = asyncio.get_running_loop()
page_events = await loop.run_in_executor(
None, self._load_events_from_files, page_files
)
return EventPage(items=page_events, next_page_id=next_page_id)
async def count_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
return len(files)
async def save_event(self, conversation_id: UUID, event: Event):
"""Save an event. Internal method intended not be part of the REST api."""
conversation = (
await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
)
if not conversation:
# This is either an illegal state or somebody is trying to hack
raise OpenHandsError('No such conversation: {conversaiont_id}')
self._save_event_to_file(conversation_id, event)
class FilesystemEventServiceInjector(EventServiceInjector):
async def inject(

View File

@@ -0,0 +1,224 @@
import asyncio
import glob
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from uuid import UUID
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
class FilesystemEventServiceBase:
events_dir: Path
async def get_event(self, event_id: str) -> Event | None:
"""Get the event with the given id, or None if not found."""
# Convert event_id to hex format (remove dashes) for filename matching
if isinstance(event_id, str) and '-' in event_id:
id_hex = event_id.replace('-', '')
else:
id_hex = event_id
# Use glob pattern to find files ending with the event_id
pattern = f'*_{id_hex}'
files = self._get_event_files_by_pattern(pattern)
files = await self._filter_files_by_conversation(files)
if not files:
return None
# Load and return the first matching event
return self._load_event_from_file(files[0])
async def search_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search for events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
files.sort(
key=lambda f: f.name,
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
)
# Handle pagination
start_index = 0
if page_id:
for i, file_path in enumerate(files):
if file_path.name == page_id:
start_index = i + 1
break
# Collect items for this page
page_files = files[start_index : start_index + limit]
next_page_id = None
if start_index + limit < len(files):
next_page_id = files[start_index + limit].name
# Load all events from files in a background thread.
loop = asyncio.get_running_loop()
page_events = await loop.run_in_executor(
None, self._load_events_from_files, page_files
)
return EventPage(items=page_events, next_page_id=next_page_id)
async def count_events(
self,
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events matching the given filters."""
# Build the search pattern
pattern = '*'
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
files = await self._filter_files_by_conversation(files)
files = self._filter_files_by_criteria(
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
)
return len(files)
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
timestamp_str = self._timestamp_to_str(event.timestamp)
kind = event.__class__.__name__
# Handle both UUID objects and string UUIDs
if isinstance(event.id, str):
id_hex = event.id.replace('-', '')
else:
id_hex = event.id.hex
return f'{timestamp_str}_{kind}_{id_hex}'
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
"""Convert timestamp to YYYYMMDDHHMMSS format."""
if isinstance(timestamp, str):
# Parse ISO format timestamp string
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
return dt.strftime('%Y%m%d%H%M%S')
return timestamp.strftime('%Y%m%d%H%M%S')
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
events = []
for file_path in file_paths:
event = self._load_event_from_file(file_path)
if event is not None:
events.append(event)
return events
def _load_event_from_file(self, filepath: Path) -> Event | None:
"""Load an event from a file."""
try:
json_data = filepath.read_text()
return Event.model_validate_json(json_data)
except Exception:
return None
def _get_event_files_by_pattern(
self, pattern: str, conversation_id: UUID | None = None
) -> list[Path]:
"""Get event files matching a glob pattern, sorted by timestamp."""
if conversation_id:
search_path = self.events_dir / str(conversation_id) / pattern
else:
search_path = self.events_dir / '*' / pattern
files = glob.glob(str(search_path))
return sorted([Path(f) for f in files])
def _parse_filename(self, filename: str) -> dict[str, str] | None:
"""Parse filename to extract timestamp, kind, and event_id."""
try:
parts = filename.split('_')
if len(parts) >= 3:
timestamp_str = parts[0]
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
event_id = parts[-1]
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
except Exception:
pass
return None
def _get_conversation_id(self, file: Path) -> UUID | None:
try:
return UUID(file.parent.name)
except Exception:
return None
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
result = set()
for file in files:
conversation_id = self._get_conversation_id(file)
if conversation_id:
result.add(conversation_id)
return result
@abstractmethod
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
"""Filter files by conversation."""
def _filter_files_by_criteria(
self,
files: list[Path],
conversation_id__eq: UUID | None = None,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
) -> list[Path]:
"""Filter files based on search criteria."""
filtered_files = []
for file_path in files:
# Check conversation_id filter
if conversation_id__eq:
if str(conversation_id__eq) not in str(file_path):
continue
# Parse filename for additional filtering
filename_info = self._parse_filename(file_path.name)
if not filename_info:
continue
# Check kind filter
if kind__eq and filename_info['kind'] != kind__eq:
continue
# Check timestamp filters
if timestamp__gte or timestamp__lt:
try:
file_timestamp = datetime.strptime(
filename_info['timestamp'], '%Y%m%d%H%M%S'
)
if timestamp__gte and file_timestamp < timestamp__gte:
continue
if timestamp__lt and file_timestamp >= timestamp__lt:
continue
except ValueError:
continue
filtered_files.append(file_path)
return filtered_files

View File

@@ -29,3 +29,4 @@ class ConversationInfo:
pr_number: list[int] = field(default_factory=list)
conversation_version: str = 'V0'
sub_conversation_ids: list[str] = field(default_factory=list)
public: bool | None = None

View File

@@ -1501,4 +1501,5 @@ def _to_conversation_info(app_conversation: AppConversation) -> ConversationInfo
sub_conversation_ids=[
sub_id.hex for sub_id in app_conversation.sub_conversation_ids
],
public=app_conversation.public,
)

View File

@@ -39,3 +39,4 @@ class ConversationMetadata:
# V1 compatibility
sandbox_id: str | None = None
conversation_version: str | None = None
public: bool | None = None

15
poetry.lock generated
View File

@@ -12707,18 +12707,19 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests
[[package]]
name = "pytest-asyncio"
version = "1.1.0"
version = "1.3.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.9"
groups = ["test"]
python-versions = ">=3.10"
groups = ["dev", "test"]
files = [
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
{file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"},
{file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"},
]
[package.dependencies]
pytest = ">=8.2,<9"
pytest = ">=8.2,<10"
typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""}
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
@@ -16823,4 +16824,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "9764f3b69ec8ed35feebd78a826bbc6bfa4ac6d5b56bc999be8bc738b644e538"
content-hash = "e24ceb52bccd0c80f52c408215ccf007475eb69e10b895053ea49c7e3e4be3b8"

View File

@@ -139,6 +139,7 @@ pre-commit = "4.2.0"
build = "*"
types-setuptools = "*"
pytest = "^8.4.0"
pytest-asyncio = "^1.3.0"
[tool.poetry.group.test]
optional = true