ALL-4634: implement public conversation sharing feature (#12044)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2025-12-26 10:02:01 -07:00
committed by GitHub
parent cb8c1fa263
commit a829d10213
30 changed files with 2191 additions and 250 deletions

View File

@@ -0,0 +1,41 @@
"""add public column to conversation_metadata
Revision ID: 085
Revises: 084
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '085'
down_revision: Union[str, None] = '084'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('public', sa.Boolean(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_public'),
'conversation_metadata',
['public'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_public'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'public')

View File

@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-ai"
version = "0.0.0-post.5687+7853b41ad"
version = "0.0.0-post.5750+f19fb1043"
description = "OpenHands: Code Less, Make More"
optional = false
python-versions = "^3.12,<3.14"

View File

@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
from server.sharing.shared_conversation_router import ( # noqa: E402
router as shared_conversation_router,
)
from server.sharing.shared_event_router import ( # noqa: E402
router as shared_event_router,
)
from openhands.server.app import app as base_app # noqa: E402
from openhands.server.listen_socket import sio # noqa: E402
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
base_app.include_router(
billing_router
) # Add routes for credit management and Stripe payment integration
base_app.include_router(shared_conversation_router)
base_app.include_router(shared_event_router)
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
if GITHUB_APP_CLIENT_ID:
@@ -99,6 +107,7 @@ base_app.include_router(
event_webhook_router
) # Add routes for Events in nested runtimes
base_app.add_middleware(
CORSMiddleware,
allow_origins=PERMITTED_CORS_ORIGINS,

View File

@@ -0,0 +1,20 @@
# Sharing Package
This package contains functionality for sharing conversations.
## Components
- **shared.py**: Data models for shared conversations
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
- **shared_event_service.py**: Service interface for accessing shared events
- **shared_event_service_impl.py**: Implementation of the shared event service
- **shared_conversation_router.py**: REST API endpoints for shared conversations
- **shared_event_router.py**: REST API endpoints for shared events
## Features
- Read-only access to shared conversations
- Event access for shared conversations
- Search and filtering capabilities
- Pagination support

View File

@@ -0,0 +1,142 @@
"""Implementation of SharedEventService.
This implementation provides read-only access to events from shared conversations:
- Validates that the conversation is shared before returning events
- Uses existing EventService for actual event retrieval
- Uses SharedConversationInfoService for shared conversation validation
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_event_service import (
SharedEventService,
SharedEventServiceInjector,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
logger = logging.getLogger(__name__)
@dataclass
class SharedEventServiceImpl(SharedEventService):
"""Implementation of SharedEventService that validates shared access."""
shared_conversation_info_service: SharedConversationInfoService
event_service: EventService
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return None
# If conversation is shared, get the event
return await self.event_service.get_event(event_id)
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
# Return empty page if conversation is not shared
return EventPage(items=[], next_page_id=None)
# If conversation is shared, search events for this conversation
return await self.event_service.search_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return 0
# If conversation is shared, count events for this conversation
return await self.event_service.count_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
class SharedEventServiceImplInjector(SharedEventServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedEventService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import (
get_db_session,
get_event_service,
)
async with (
get_db_session(state, request) as db_session,
get_event_service(state, request) as event_service,
):
shared_conversation_info_service = SQLSharedConversationInfoService(
db_session=db_session
)
service = SharedEventServiceImpl(
shared_conversation_info_service=shared_conversation_info_service,
event_service=event_service,
)
yield service

View File

@@ -0,0 +1,66 @@
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from openhands.app_server.services.injector import Injector
from openhands.sdk.utils.models import DiscriminatedUnionMixin
class SharedConversationInfoService(ABC):
"""Service for accessing shared conversation info without user restrictions."""
@abstractmethod
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
@abstractmethod
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations."""
@abstractmethod
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single shared conversation info, returning None if missing or not shared."""
async def batch_get_shared_conversation_info(
self, conversation_ids: list[UUID]
) -> list[SharedConversation | None]:
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
return await asyncio.gather(
*[
self.get_shared_conversation_info(conversation_id)
for conversation_id in conversation_ids
]
)
class SharedConversationInfoServiceInjector(
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
):
pass

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from enum import Enum
# Simplified imports to avoid dependency chain issues
# from openhands.integrations.service_types import ProviderType
# from openhands.sdk.llm import MetricsSnapshot
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# For now, use Any to avoid import issues
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from openhands.agent_server.utils import OpenHandsUUID, utc_now
ProviderType = Any
MetricsSnapshot = Any
ConversationTrigger = Any
class SharedConversation(BaseModel):
"""Shared conversation info model with all fields from AppConversationInfo."""
id: OpenHandsUUID = Field(default_factory=uuid4)
created_by_user_id: str | None
sandbox_id: str
selected_repository: str | None = None
selected_branch: str | None = None
git_provider: ProviderType | None = None
title: str | None = None
pr_number: list[int] = Field(default_factory=list)
llm_model: str | None = None
metrics: MetricsSnapshot | None = None
parent_conversation_id: OpenHandsUUID | None = None
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
class SharedConversationSortOrder(Enum):
CREATED_AT = 'CREATED_AT'
CREATED_AT_DESC = 'CREATED_AT_DESC'
UPDATED_AT = 'UPDATED_AT'
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
TITLE = 'TITLE'
TITLE_DESC = 'TITLE_DESC'
class SharedConversationPage(BaseModel):
items: list[SharedConversation]
next_page_id: str | None = None

View File

@@ -0,0 +1,135 @@
"""Shared Conversation router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoServiceInjector,
)
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
shared_conversation_info_service_dependency = Depends(
SQLSharedConversationInfoServiceInjector().depends
)
# Read methods
@router.get('/search')
async def search_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
sort_order: Annotated[
SharedConversationSortOrder,
Query(title='Sort order for results'),
] = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(
title='The max number of results in the page',
gt=0,
lte=100,
),
] = 100,
include_sub_conversations: Annotated[
bool,
Query(
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
),
] = False,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
"""Search / List shared conversations."""
assert limit > 0
assert limit <= 100
return await shared_conversation_service.search_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
include_sub_conversations=include_sub_conversations,
)
@router.get('/count')
async def count_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> int:
"""Count shared conversations matching the given filters."""
return await shared_conversation_service.count_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
@router.get('')
async def batch_get_shared_conversations(
ids: Annotated[list[str], Query()],
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> list[SharedConversation | None]:
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
assert len(ids) <= 100
uuids = [UUID(id_) for id_ in ids]
shared_conversation_info = (
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
)
return shared_conversation_info

View File

@@ -0,0 +1,126 @@
"""Shared Event router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImplInjector,
)
from server.sharing.shared_event_service import SharedEventService
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
# Read methods
@router.get('/search')
async def search_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to search events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
"""Search / List events for a shared conversation."""
assert limit > 0
assert limit <= 100
return await shared_event_service.search_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
@router.get('/count')
async def count_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to count events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> int:
"""Count events for a shared conversation matching the given filters."""
return await shared_event_service.count_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
@router.get('')
async def batch_get_shared_events(
conversation_id: Annotated[
UUID,
Query(title='Conversation ID to get events for'),
],
id: Annotated[list[str], Query()],
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> list[Event | None]:
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
assert len(id) <= 100
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
return events
@router.get('/{conversation_id}/{event_id}')
async def get_shared_event(
conversation_id: UUID,
event_id: str,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
"""Get a single event from a shared conversation by conversation_id and event_id."""
return await shared_event_service.get_shared_event(conversation_id, event_id)

View File

@@ -0,0 +1,64 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import Injector
from openhands.sdk import Event
from openhands.sdk.utils.models import DiscriminatedUnionMixin
_logger = logging.getLogger(__name__)
class SharedEventService(ABC):
"""Event Service for getting events from shared conversations only."""
@abstractmethod
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
@abstractmethod
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
@abstractmethod
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
async def batch_get_shared_events(
self, conversation_id: UUID, event_ids: list[str]
) -> list[Event | None]:
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
return await asyncio.gather(
*[
self.get_shared_event(conversation_id, event_id)
for event_id in event_ids
]
)
class SharedEventServiceInjector(
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
):
pass

View File

@@ -0,0 +1,282 @@
"""SQL implementation of SharedConversationInfoService.
This implementation provides read-only access to shared conversations:
- Direct database access without user permission checks
- Filters only conversations marked as shared (currently public)
- Full async/await support using SQL async db_sessions
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
SharedConversationInfoServiceInjector,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.services.injector import InjectorState
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
logger = logging.getLogger(__name__)
@dataclass
class SQLSharedConversationInfoService(SharedConversationInfoService):
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
db_session: AsyncSession
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == SharedConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == SharedConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.scalars().all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [self._to_shared_conversation(row) for row in rows]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return SharedConversationPage(items=items, next_page_id=next_page_id)
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations matching the given filters."""
from sqlalchemy import func
query = select(func.count(StoredConversationMetadata.conversation_id))
# Only include shared conversations
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
return result.scalar() or 0
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
stored = result.scalar_one_or_none()
if stored is None:
return None
return self._to_shared_conversation(stored)
def _public_select(self):
"""Create a select query that only returns public conversations."""
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
# Only include conversations marked as public
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _apply_filters(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply common filters to a query."""
if title__contains is not None:
query = query.where(
StoredConversationMetadata.title.contains(title__contains)
)
if created_at__gte is not None:
query = query.where(
StoredConversationMetadata.created_at >= created_at__gte
)
if created_at__lt is not None:
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
query = query.where(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
query = query.where(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
return query
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation."""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
# Rebuild token usage
token_usage = TokenUsage(
prompt_tokens=stored.prompt_tokens,
completion_tokens=stored.completion_tokens,
cache_read_tokens=stored.cache_read_tokens,
cache_write_tokens=stored.cache_write_tokens,
context_window=stored.context_window,
per_turn_token=stored.per_turn_token,
)
# Rebuild metrics object
metrics = MetricsSnapshot(
accumulated_cost=stored.accumulated_cost,
max_budget_per_task=stored.max_budget_per_task,
accumulated_token_usage=token_usage,
)
# Get timestamps
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
git_provider=(
ProviderType(stored.git_provider) if stored.git_provider else None
),
title=stored.title,
pr_number=stored.pr_number,
llm_model=stored.llm_model,
metrics=metrics,
parent_conversation_id=(
UUID(stored.parent_conversation_id)
if stored.parent_conversation_id
else None
),
sub_conversation_ids=sub_conversation_ids or [],
created_at=created_at,
updated_at=updated_at,
)
def _fix_timezone(self, value: datetime) -> datetime:
"""Sqlite does not store timezones - and since we can't update the existing models
we assume UTC if the timezone is missing."""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
return value
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedConversationInfoService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import get_db_session
async with get_db_session(state, request) as db_session:
service = SQLSharedConversationInfoService(db_session=db_session)
yield service

View File

@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
kwargs.pop('context_window', None)
kwargs.pop('per_turn_token', None)
kwargs.pop('parent_conversation_id', None)
kwargs.pop('public')
return ConversationMetadata(**kwargs)

View File

@@ -0,0 +1 @@
"""Tests for sharing package."""

View File

@@ -0,0 +1,91 @@
"""Tests for public conversation models."""
from datetime import datetime
from uuid import uuid4
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
def test_public_conversation_creation():
"""Test that SharedConversation can be created with all required fields."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
assert conversation.id == conversation_id
assert conversation.title == 'Test Conversation'
assert conversation.created_by_user_id == 'test_user'
assert conversation.sandbox_id == 'test_sandbox'
def test_public_conversation_page_creation():
"""Test that SharedConversationPage can be created."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
page = SharedConversationPage(
items=[conversation],
next_page_id='next_page',
)
assert len(page.items) == 1
assert page.items[0].id == conversation_id
assert page.next_page_id == 'next_page'
def test_public_conversation_sort_order_enum():
"""Test that SharedConversationSortOrder enum has expected values."""
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'TITLE')
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
def test_public_conversation_optional_fields():
"""Test that SharedConversation works with optional fields."""
conversation_id = uuid4()
parent_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository='owner/repo',
parent_conversation_id=parent_id,
llm_model='gpt-4',
)
assert conversation.selected_repository == 'owner/repo'
assert conversation.parent_conversation_id == parent_id
assert conversation.llm_model == 'gpt-4'

View File

@@ -0,0 +1,430 @@
"""Tests for SharedConversationInfoService."""
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import uuid4
import pytest
from server.sharing.shared_conversation_models import (
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
async def shared_conversation_info_service(async_session):
"""Create a SharedConversationInfoService for testing."""
return SQLSharedConversationInfoService(db_session=async_session)
@pytest.fixture
async def app_conversation_service(async_session):
"""Create an AppConversationInfoService for creating test data."""
return SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
def sample_conversation_info():
"""Create a sample conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
selected_repository='test/repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Test Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[123],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=1.5,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=150,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=True, # Make it public for testing
)
@pytest.fixture
def sample_private_conversation_info():
"""Create a sample private conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_private',
selected_repository='test/private_repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Private Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[124],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=2.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=300,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=False, # Make it private
)
class TestSharedConversationInfoService:
"""Test cases for SharedConversationInfoService."""
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_public_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test that get_shared_conversation_info returns a public conversation."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_conversation_info.id
)
assert result is not None
assert result.id == sample_conversation_info.id
assert result.title == sample_conversation_info.title
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_private_conversation_info,
):
"""Test that get_shared_conversation_info returns None for private conversations."""
# Create a private conversation
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Try to retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_private_conversation_info.id
)
assert result is None
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
self, shared_conversation_info_service
):
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
nonexistent_id = uuid4()
result = await shared_conversation_info_service.get_shared_conversation_info(
nonexistent_id
)
assert result is None
@pytest.mark.asyncio
async def test_search_shared_conversation_info_returns_only_public_conversations(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test that search only returns public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Search for all conversations
result = (
await shared_conversation_info_service.search_shared_conversation_info()
)
# Should only return the public conversation
assert len(result.items) == 1
assert result.items[0].id == sample_conversation_info.id
assert result.items[0].title == sample_conversation_info.title
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_title_filter(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test searching with title filter."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Search with matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='Test'
)
assert len(result.items) == 1
# Search with non-matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='NonExistent'
)
assert len(result.items) == 0
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_sort_order(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test searching with different sort orders."""
# Create multiple public conversations with different titles and timestamps
conv1 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_1',
title='A First Conversation',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv2 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_2',
title='B Second Conversation',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_conversation_service.save_app_conversation_info(conv1)
await app_conversation_service.save_app_conversation_info(conv2)
# Test sort by title ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE
)
assert len(result.items) == 2
assert result.items[0].title == 'A First Conversation'
assert result.items[1].title == 'B Second Conversation'
# Test sort by title descending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE_DESC
)
assert len(result.items) == 2
assert result.items[0].title == 'B Second Conversation'
assert result.items[1].title == 'A First Conversation'
# Test sort by created_at ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.items[0].id == conv1.id
assert result.items[1].id == conv2.id
# Test sort by created_at descending (default)
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
)
assert len(result.items) == 2
assert result.items[0].id == conv2.id
assert result.items[1].id == conv1.id
@pytest.mark.asyncio
async def test_count_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test counting public conversations."""
# Initially should be 0
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 0
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
# Create a private conversation - count should remain 1
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
@pytest.mark.asyncio
async def test_batch_get_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test batch getting public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Batch get both conversations
result = (
await shared_conversation_info_service.batch_get_shared_conversation_info(
[sample_conversation_info.id, sample_private_conversation_info.id]
)
)
# Should return the public one and None for the private one
assert len(result) == 2
assert result[0] is not None
assert result[0].id == sample_conversation_info.id
assert result[1] is None
@pytest.mark.asyncio
async def test_search_with_pagination(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test search with pagination."""
# Create multiple public conversations
conversations = []
for i in range(5):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id=f'test_sandbox_{i}',
title=f'Conversation {i}',
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conversations.append(conv)
await app_conversation_service.save_app_conversation_info(conv)
# Get first page with limit 2
result = await shared_conversation_info_service.search_shared_conversation_info(
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.next_page_id is not None
# Get next page
result2 = (
await shared_conversation_info_service.search_shared_conversation_info(
limit=2,
page_id=result.next_page_id,
sort_order=SharedConversationSortOrder.CREATED_AT,
)
)
assert len(result2.items) == 2
assert result2.next_page_id is not None
# Verify no overlap between pages
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)

View File

@@ -0,0 +1,365 @@
"""Tests for SharedEventService."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImpl,
)
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import SharedConversation
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
@pytest.fixture
def mock_shared_conversation_info_service():
"""Create a mock SharedConversationInfoService."""
return AsyncMock(spec=SharedConversationInfoService)
@pytest.fixture
def mock_event_service():
"""Create a mock EventService."""
return AsyncMock(spec=EventService)
@pytest.fixture
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
"""Create a SharedEventService for testing."""
return SharedEventServiceImpl(
shared_conversation_info_service=mock_shared_conversation_info_service,
event_service=mock_event_service,
)
@pytest.fixture
def sample_public_conversation():
"""Create a sample public conversation."""
return SharedConversation(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Public Conversation',
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
@pytest.fixture
def sample_event():
"""Create a sample event."""
# For testing purposes, we'll just use a mock that the EventPage can accept
# The actual event creation is complex and not the focus of these tests
return None
class TestSharedEventService:
"""Test cases for SharedEventService."""
async def test_get_shared_event_returns_event_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that get_shared_event returns an event for a public conversation."""
conversation_id = sample_public_conversation.id
event_id = 'test_event_id'
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return an event
mock_event_service.get_event.return_value = sample_event
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result == sample_event
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.get_event.assert_called_once_with(event_id)
async def test_get_shared_event_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that get_shared_event returns None for a private conversation."""
conversation_id = uuid4()
event_id = 'test_event_id'
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that search_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id=None)
mock_event_service.search_events.return_value = mock_event_page
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
limit=10,
)
# Verify the result
assert result == mock_event_page
assert len(result.items) == 0 # Empty list as we mocked
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
page_id=None,
limit=10,
)
async def test_search_shared_events_returns_empty_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that search_shared_events returns empty page for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
limit=10,
)
# Verify the result
assert isinstance(result, EventPage)
assert len(result.items) == 0
assert result.next_page_id is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.search_events.assert_not_called()
async def test_count_shared_events_returns_count_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test that count_shared_events returns count for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return a count
mock_event_service.count_events.return_value = 5
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
)
# Verify the result
assert result == 5
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.count_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
)
async def test_count_shared_events_returns_zero_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that count_shared_events returns 0 for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
)
# Verify the result
assert result == 0
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.count_events.assert_not_called()
async def test_batch_get_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that batch_get_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
event_ids = ['event1', 'event2']
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_service.get_event.side_effect = [sample_event, None]
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] == sample_event
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Verify that get_event was called for each event
assert mock_event_service.get_event.call_count == 2
async def test_batch_get_shared_events_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that batch_get_shared_events returns None for a private conversation."""
conversation_id = uuid4()
event_ids = ['event1', 'event2']
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] is None
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_with_all_parameters(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test search_shared_events with all parameters."""
conversation_id = sample_public_conversation.id
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id='next_page')
mock_event_service.search_events.return_value = mock_event_page
# Call the method with all parameters
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)
# Verify the result
assert result == mock_event_page
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)