mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 06:23:59 -05:00
ALL-4634: implement public conversation sharing feature (#12044)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""add public column to conversation_metadata
|
||||
|
||||
Revision ID: 085
|
||||
Revises: 084
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '085'
|
||||
down_revision: Union[str, None] = '084'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema."""
    # Nullable so existing rows need no backfill; a NULL/absent value simply
    # means the conversation is not shared.
    op.add_column(
        'conversation_metadata',
        sa.Column('public', sa.Boolean(), nullable=True),
    )
    # Non-unique index to support "list all public conversations" queries.
    op.create_index(
        op.f('ix_conversation_metadata_public'),
        'conversation_metadata',
        ['public'],
        unique=False,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema."""
    # Drop the index before the column it covers (reverse of upgrade order).
    op.drop_index(
        op.f('ix_conversation_metadata_public'),
        table_name='conversation_metadata',
    )
    op.drop_column('conversation_metadata', 'public')
|
||||
2
enterprise/poetry.lock
generated
2
enterprise/poetry.lock
generated
@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "0.0.0-post.5687+7853b41ad"
|
||||
version = "0.0.0-post.5750+f19fb1043"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
|
||||
@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
base_app.include_router(shared_conversation_router)
|
||||
base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
@@ -99,6 +107,7 @@ base_app.include_router(
|
||||
event_webhook_router
|
||||
) # Add routes for Events in nested runtimes
|
||||
|
||||
|
||||
base_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=PERMITTED_CORS_ORIGINS,
|
||||
|
||||
20
enterprise/server/sharing/README.md
Normal file
20
enterprise/server/sharing/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Sharing Package
|
||||
|
||||
This package contains functionality for sharing conversations.
|
||||
|
||||
## Components
|
||||
|
||||
- **shared_conversation_models.py**: Data models for shared conversations
|
||||
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
|
||||
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
|
||||
- **shared_event_service.py**: Service interface for accessing shared events
|
||||
- **filesystem_shared_event_service.py**: Implementation of the shared event service
|
||||
- **shared_conversation_router.py**: REST API endpoints for shared conversations
|
||||
- **shared_event_router.py**: REST API endpoints for shared events
|
||||
|
||||
## Features
|
||||
|
||||
- Read-only access to shared conversations
|
||||
- Event access for shared conversations
|
||||
- Search and filtering capabilities
|
||||
- Pagination support
|
||||
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Implementation of SharedEventService.
|
||||
|
||||
This implementation provides read-only access to events from shared conversations:
|
||||
- Validates that the conversation is shared before returning events
|
||||
- Uses existing EventService for actual event retrieval
|
||||
- Uses SharedConversationInfoService for shared conversation validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SharedEventServiceImpl(SharedEventService):
    """SharedEventService backed by an EventService, gated on shared status.

    Every accessor first verifies through SharedConversationInfoService that
    the conversation is shared; only then is the underlying EventService used.
    Non-shared conversations yield None / empty / zero rather than errors.
    """

    shared_conversation_info_service: SharedConversationInfoService
    event_service: EventService

    async def _is_shared(self, conversation_id: UUID) -> bool:
        """Return True when the conversation is currently shared."""
        info = await self.shared_conversation_info_service.get_shared_conversation_info(
            conversation_id
        )
        return info is not None

    async def get_shared_event(
        self, conversation_id: UUID, event_id: str
    ) -> Event | None:
        """Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
        if not await self._is_shared(conversation_id):
            return None
        return await self.event_service.get_event(event_id)

    async def search_shared_events(
        self,
        conversation_id: UUID,
        kind__eq: EventKind | None = None,
        timestamp__gte: datetime | None = None,
        timestamp__lt: datetime | None = None,
        sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
        page_id: str | None = None,
        limit: int = 100,
    ) -> EventPage:
        """Search events for a specific shared conversation."""
        if not await self._is_shared(conversation_id):
            # A non-shared conversation looks empty rather than erroring.
            return EventPage(items=[], next_page_id=None)
        return await self.event_service.search_events(
            conversation_id__eq=conversation_id,
            kind__eq=kind__eq,
            timestamp__gte=timestamp__gte,
            timestamp__lt=timestamp__lt,
            sort_order=sort_order,
            page_id=page_id,
            limit=limit,
        )

    async def count_shared_events(
        self,
        conversation_id: UUID,
        kind__eq: EventKind | None = None,
        timestamp__gte: datetime | None = None,
        timestamp__lt: datetime | None = None,
        sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
    ) -> int:
        """Count events for a specific shared conversation."""
        if not await self._is_shared(conversation_id):
            return 0
        return await self.event_service.count_events(
            conversation_id__eq=conversation_id,
            kind__eq=kind__eq,
            timestamp__gte=timestamp__gte,
            timestamp__lt=timestamp__lt,
            sort_order=sort_order,
        )
|
||||
|
||||
|
||||
class SharedEventServiceImplInjector(SharedEventServiceInjector):
    async def inject(
        self, state: InjectorState, request: Request | None = None
    ) -> AsyncGenerator[SharedEventService, None]:
        """Yield a SharedEventServiceImpl wired to a DB session and EventService."""
        # Imported lazily to prevent a circular lookup at module import time.
        from openhands.app_server.config import (
            get_db_session,
            get_event_service,
        )

        async with (
            get_db_session(state, request) as db_session,
            get_event_service(state, request) as event_service,
        ):
            yield SharedEventServiceImpl(
                shared_conversation_info_service=SQLSharedConversationInfoService(
                    db_session=db_session
                ),
                event_service=event_service,
            )
|
||||
@@ -0,0 +1,66 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
|
||||
class SharedConversationInfoService(ABC):
    """Service for accessing shared conversation info without user restrictions."""

    @abstractmethod
    async def search_shared_conversation_info(
        self,
        title__contains: str | None = None,
        created_at__gte: datetime | None = None,
        created_at__lt: datetime | None = None,
        updated_at__gte: datetime | None = None,
        updated_at__lt: datetime | None = None,
        sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
        page_id: str | None = None,
        limit: int = 100,
        include_sub_conversations: bool = False,
    ) -> SharedConversationPage:
        """Search for shared conversations matching the given filters."""

    @abstractmethod
    async def count_shared_conversation_info(
        self,
        title__contains: str | None = None,
        created_at__gte: datetime | None = None,
        created_at__lt: datetime | None = None,
        updated_at__gte: datetime | None = None,
        updated_at__lt: datetime | None = None,
    ) -> int:
        """Count shared conversations matching the given filters."""

    @abstractmethod
    async def get_shared_conversation_info(
        self, conversation_id: UUID
    ) -> SharedConversation | None:
        """Get a single shared conversation info, returning None if missing or not shared."""

    async def batch_get_shared_conversation_info(
        self, conversation_ids: list[UUID]
    ) -> list[SharedConversation | None]:
        """Get a batch of shared conversation info, return None for any missing or non-shared."""
        # Resolve all ids concurrently; each lookup independently yields None
        # for missing or non-shared conversations.
        lookups = (
            self.get_shared_conversation_info(conversation_id)
            for conversation_id in conversation_ids
        )
        return await asyncio.gather(*lookups)
|
||||
|
||||
|
||||
class SharedConversationInfoServiceInjector(
    DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
):
    # Marker base for dependency-injection of SharedConversationInfoService
    # implementations; concrete injectors subclass this and implement inject().
    pass
|
||||
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
# Simplified imports to avoid dependency chain issues
|
||||
# from openhands.integrations.service_types import ProviderType
|
||||
# from openhands.sdk.llm import MetricsSnapshot
|
||||
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
# For now, use Any to avoid import issues
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.agent_server.utils import OpenHandsUUID, utc_now
|
||||
|
||||
ProviderType = Any
|
||||
MetricsSnapshot = Any
|
||||
ConversationTrigger = Any
|
||||
|
||||
|
||||
class SharedConversation(BaseModel):
    """Shared conversation info model with all fields from AppConversationInfo."""

    # Conversation identifier; a random UUID is generated when not supplied.
    id: OpenHandsUUID = Field(default_factory=uuid4)

    # Owner of the conversation, if known.
    created_by_user_id: str | None
    sandbox_id: str

    # Git context the conversation was started from, when applicable.
    selected_repository: str | None = None
    selected_branch: str | None = None
    git_provider: ProviderType | None = None  # ProviderType aliased to Any above
    title: str | None = None
    # Pull-request numbers associated with this conversation.
    pr_number: list[int] = Field(default_factory=list)
    llm_model: str | None = None

    # Aggregated LLM usage metrics (MetricsSnapshot aliased to Any above).
    metrics: MetricsSnapshot | None = None

    # Parent/child links for nested sub-conversations.
    parent_conversation_id: OpenHandsUUID | None = None
    sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)

    # Timestamps default to utc_now at model creation time.
    created_at: datetime = Field(default_factory=utc_now)
    updated_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
|
||||
class SharedConversationSortOrder(Enum):
    """Sort orders for listing shared conversations.

    Plain members sort ascending; *_DESC members sort descending.
    """

    CREATED_AT = 'CREATED_AT'
    CREATED_AT_DESC = 'CREATED_AT_DESC'
    UPDATED_AT = 'UPDATED_AT'
    UPDATED_AT_DESC = 'UPDATED_AT_DESC'
    TITLE = 'TITLE'
    TITLE_DESC = 'TITLE_DESC'
|
||||
|
||||
|
||||
class SharedConversationPage(BaseModel):
    """One page of shared-conversation search results."""

    # Conversations in this page.
    items: list[SharedConversation]
    # Token for fetching the next page; None when there are no more results.
    next_page_id: str | None = None
|
||||
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Shared Conversation router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoServiceInjector,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
|
||||
shared_conversation_info_service_dependency = Depends(
|
||||
SQLSharedConversationInfoServiceInjector().depends
|
||||
)
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
async def search_shared_conversations(
    title__contains: Annotated[
        str | None,
        Query(title='Filter by title containing this string'),
    ] = None,
    created_at__gte: Annotated[
        datetime | None,
        Query(title='Filter by created_at greater than or equal to this datetime'),
    ] = None,
    created_at__lt: Annotated[
        datetime | None,
        Query(title='Filter by created_at less than this datetime'),
    ] = None,
    updated_at__gte: Annotated[
        datetime | None,
        Query(title='Filter by updated_at greater than or equal to this datetime'),
    ] = None,
    updated_at__lt: Annotated[
        datetime | None,
        Query(title='Filter by updated_at less than this datetime'),
    ] = None,
    sort_order: Annotated[
        SharedConversationSortOrder,
        Query(title='Sort order for results'),
    ] = SharedConversationSortOrder.CREATED_AT_DESC,
    page_id: Annotated[
        str | None,
        Query(title='Optional next_page_id from the previously returned page'),
    ] = None,
    limit: Annotated[
        int,
        Query(
            title='The max number of results in the page',
            gt=0,
            # Fixed: `lte` is not a recognized constraint keyword in
            # FastAPI/Pydantic, so the upper bound was never enforced by
            # request validation; the correct keyword is `le`.
            le=100,
        ),
    ] = 100,
    include_sub_conversations: Annotated[
        bool,
        Query(
            title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
        ),
    ] = False,
    shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
    """Search / List shared conversations."""
    # Defense-in-depth for direct (non-HTTP) callers; FastAPI already
    # enforces gt=0 / le=100 on the query parameter.
    assert limit > 0
    assert limit <= 100
    return await shared_conversation_service.search_shared_conversation_info(
        title__contains=title__contains,
        created_at__gte=created_at__gte,
        created_at__lt=created_at__lt,
        updated_at__gte=updated_at__gte,
        updated_at__lt=updated_at__lt,
        sort_order=sort_order,
        page_id=page_id,
        limit=limit,
        include_sub_conversations=include_sub_conversations,
    )
|
||||
|
||||
|
||||
@router.get('/count')
async def count_shared_conversations(
    title__contains: Annotated[
        str | None,
        Query(title='Filter by title containing this string'),
    ] = None,
    created_at__gte: Annotated[
        datetime | None,
        Query(title='Filter by created_at greater than or equal to this datetime'),
    ] = None,
    created_at__lt: Annotated[
        datetime | None,
        Query(title='Filter by created_at less than this datetime'),
    ] = None,
    updated_at__gte: Annotated[
        datetime | None,
        Query(title='Filter by updated_at greater than or equal to this datetime'),
    ] = None,
    updated_at__lt: Annotated[
        datetime | None,
        Query(title='Filter by updated_at less than this datetime'),
    ] = None,
    shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> int:
    """Count shared conversations matching the given filters."""
    # Forward the filter set verbatim to the service layer.
    filters = dict(
        title__contains=title__contains,
        created_at__gte=created_at__gte,
        created_at__lt=created_at__lt,
        updated_at__gte=updated_at__gte,
        updated_at__lt=updated_at__lt,
    )
    return await shared_conversation_service.count_shared_conversation_info(**filters)
|
||||
|
||||
|
||||
@router.get('')
async def batch_get_shared_conversations(
    ids: Annotated[list[str], Query()],
    shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> list[SharedConversation | None]:
    """Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
    # Local import keeps this fix self-contained within the handler.
    from fastapi import HTTPException

    # `assert` is stripped under `python -O`, so it must not be used for
    # request validation; reject oversized batches with an explicit 400.
    if len(ids) > 100:
        raise HTTPException(
            status_code=400, detail='At most 100 ids may be requested per batch'
        )
    # Malformed UUIDs are a client error, not a server fault.
    try:
        uuids = [UUID(id_) for id_ in ids]
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail='Invalid conversation id'
        ) from err
    return await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
126
enterprise/server/sharing/shared_event_router.py
Normal file
126
enterprise/server/sharing/shared_event_router.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Shared Event router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImplInjector,
|
||||
)
|
||||
from server.sharing.shared_event_service import SharedEventService
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
|
||||
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
|
||||
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
|
||||
|
||||
|
||||
# Read methods
|
||||
|
||||
|
||||
@router.get('/search')
async def search_shared_events(
    conversation_id: Annotated[
        str,
        Query(title='Conversation ID to search events for'),
    ],
    kind__eq: Annotated[
        EventKind | None,
        Query(title='Optional filter by event kind'),
    ] = None,
    timestamp__gte: Annotated[
        datetime | None,
        Query(title='Optional filter by timestamp greater than or equal to'),
    ] = None,
    timestamp__lt: Annotated[
        datetime | None,
        Query(title='Optional filter by timestamp less than'),
    ] = None,
    sort_order: Annotated[
        EventSortOrder,
        Query(title='Sort order for results'),
    ] = EventSortOrder.TIMESTAMP,
    page_id: Annotated[
        str | None,
        Query(title='Optional next_page_id from the previously returned page'),
    ] = None,
    limit: Annotated[
        int,
        # Fixed: `lte` is not a recognized FastAPI/Pydantic constraint
        # keyword, so the upper bound was never enforced; use `le`.
        Query(title='The max number of results in the page', gt=0, le=100),
    ] = 100,
    shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
    """Search / List events for a shared conversation."""
    # Defense-in-depth for direct (non-HTTP) callers; FastAPI already
    # enforces gt=0 / le=100 on the query parameter.
    assert limit > 0
    assert limit <= 100
    return await shared_event_service.search_shared_events(
        conversation_id=UUID(conversation_id),
        kind__eq=kind__eq,
        timestamp__gte=timestamp__gte,
        timestamp__lt=timestamp__lt,
        sort_order=sort_order,
        page_id=page_id,
        limit=limit,
    )
|
||||
|
||||
|
||||
@router.get('/count')
async def count_shared_events(
    conversation_id: Annotated[
        str,
        Query(title='Conversation ID to count events for'),
    ],
    kind__eq: Annotated[
        EventKind | None,
        Query(title='Optional filter by event kind'),
    ] = None,
    timestamp__gte: Annotated[
        datetime | None,
        Query(title='Optional filter by timestamp greater than or equal to'),
    ] = None,
    timestamp__lt: Annotated[
        datetime | None,
        Query(title='Optional filter by timestamp less than'),
    ] = None,
    sort_order: Annotated[
        EventSortOrder,
        Query(title='Sort order for results'),
    ] = EventSortOrder.TIMESTAMP,
    shared_event_service: SharedEventService = shared_event_service_dependency,
) -> int:
    """Count events for a shared conversation matching the given filters."""
    # Forward the filter set verbatim; the path-string id is parsed to a UUID.
    filters = dict(
        kind__eq=kind__eq,
        timestamp__gte=timestamp__gte,
        timestamp__lt=timestamp__lt,
        sort_order=sort_order,
    )
    return await shared_event_service.count_shared_events(
        conversation_id=UUID(conversation_id), **filters
    )
|
||||
|
||||
|
||||
@router.get('')
async def batch_get_shared_events(
    conversation_id: Annotated[
        UUID,
        Query(title='Conversation ID to get events for'),
    ],
    # NOTE: `id` shadows the builtin, but renaming it would change the public
    # query-parameter name, so it is kept for API compatibility.
    id: Annotated[list[str], Query()],
    shared_event_service: SharedEventService = shared_event_service_dependency,
) -> list[Event | None]:
    """Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
    # Local import keeps this fix self-contained within the handler.
    from fastapi import HTTPException

    # `assert` is stripped under `python -O`, so it must not be used for
    # request validation; reject oversized batches with an explicit 400.
    if len(id) > 100:
        raise HTTPException(
            status_code=400, detail='At most 100 event ids may be requested per batch'
        )
    events = await shared_event_service.batch_get_shared_events(conversation_id, id)
    return events
|
||||
|
||||
|
||||
@router.get('/{conversation_id}/{event_id}')
async def get_shared_event(
    conversation_id: UUID,
    event_id: str,
    shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
    """Get a single event from a shared conversation by conversation_id and event_id."""
    # The service returns None when the event is missing or the
    # conversation is not shared; surface that directly.
    event = await shared_event_service.get_shared_event(conversation_id, event_id)
    return event
|
||||
64
enterprise/server/sharing/shared_event_service.py
Normal file
64
enterprise/server/sharing/shared_event_service.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedEventService(ABC):
    """Event Service for getting events from shared conversations only."""

    @abstractmethod
    async def get_shared_event(
        self, conversation_id: UUID, event_id: str
    ) -> Event | None:
        """Given a conversation_id and event_id, retrieve an event if the conversation is shared."""

    @abstractmethod
    async def search_shared_events(
        self,
        conversation_id: UUID,
        kind__eq: EventKind | None = None,
        timestamp__gte: datetime | None = None,
        timestamp__lt: datetime | None = None,
        sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
        page_id: str | None = None,
        limit: int = 100,
    ) -> EventPage:
        """Search events for a specific shared conversation."""

    @abstractmethod
    async def count_shared_events(
        self,
        conversation_id: UUID,
        kind__eq: EventKind | None = None,
        timestamp__gte: datetime | None = None,
        timestamp__lt: datetime | None = None,
        sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
    ) -> int:
        """Count events for a specific shared conversation."""

    async def batch_get_shared_events(
        self, conversation_id: UUID, event_ids: list[str]
    ) -> list[Event | None]:
        """Given a conversation_id and list of event_ids, get events if the conversation is shared."""
        # Resolve all ids concurrently; each lookup independently returns
        # None for missing events or non-shared conversations.
        pending = (
            self.get_shared_event(conversation_id, event_id)
            for event_id in event_ids
        )
        return await asyncio.gather(*pending)
|
||||
|
||||
|
||||
class SharedEventServiceInjector(
    DiscriminatedUnionMixin, Injector[SharedEventService], ABC
):
    # Marker base for dependency-injection of SharedEventService
    # implementations; concrete injectors subclass this and implement inject().
    pass
|
||||
@@ -0,0 +1,282 @@
|
||||
"""SQL implementation of SharedConversationInfoService.
|
||||
|
||||
This implementation provides read-only access to shared conversations:
|
||||
- Direct database access without user permission checks
|
||||
- Filters only conversations marked as shared (currently public)
|
||||
- Full async/await support using SQL async db_sessions
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
SharedConversationInfoServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_shared_conversation_info(
    self,
    title__contains: str | None = None,
    created_at__gte: datetime | None = None,
    created_at__lt: datetime | None = None,
    updated_at__gte: datetime | None = None,
    updated_at__lt: datetime | None = None,
    sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
    page_id: str | None = None,
    limit: int = 100,
    include_sub_conversations: bool = False,
) -> SharedConversationPage:
    """Search for shared conversations."""
    query = self._public_select()

    if not include_sub_conversations:
        # Restrict to top-level conversations (no parent).
        query = query.where(
            StoredConversationMetadata.parent_conversation_id.is_(None)
        )

    query = self._apply_filters(
        query=query,
        title__contains=title__contains,
        created_at__gte=created_at__gte,
        created_at__lt=created_at__lt,
        updated_at__gte=updated_at__gte,
        updated_at__lt=updated_at__lt,
    )

    # Dispatch table mapping each sort order to its ORDER BY expression.
    order_by_expressions = {
        SharedConversationSortOrder.CREATED_AT: StoredConversationMetadata.created_at,
        SharedConversationSortOrder.CREATED_AT_DESC: StoredConversationMetadata.created_at.desc(),
        SharedConversationSortOrder.UPDATED_AT: StoredConversationMetadata.last_updated_at,
        SharedConversationSortOrder.UPDATED_AT_DESC: StoredConversationMetadata.last_updated_at.desc(),
        SharedConversationSortOrder.TITLE: StoredConversationMetadata.title,
        SharedConversationSortOrder.TITLE_DESC: StoredConversationMetadata.title.desc(),
    }
    ordering = order_by_expressions.get(sort_order)
    if ordering is not None:
        query = query.order_by(ordering)

    # page_id is an integer offset; an unparsable value restarts at 0.
    offset = 0
    if page_id is not None:
        try:
            offset = int(page_id)
            query = query.offset(offset)
        except ValueError:
            offset = 0

    # Fetch one extra row to detect whether another page exists.
    query = query.limit(limit + 1)

    result = await self.db_session.execute(query)
    rows = list(result.scalars().all())

    has_more = len(rows) > limit
    if has_more:
        rows = rows[:limit]

    return SharedConversationPage(
        items=[self._to_shared_conversation(row) for row in rows],
        next_page_id=str(offset + limit) if has_more else None,
    )
|
||||
|
||||
async def count_shared_conversation_info(
    self,
    title__contains: str | None = None,
    created_at__gte: datetime | None = None,
    created_at__lt: datetime | None = None,
    updated_at__gte: datetime | None = None,
    updated_at__lt: datetime | None = None,
) -> int:
    """Count shared conversations matching the given filters."""
    from sqlalchemy import func

    # COUNT over the same base restrictions as the search query:
    # public conversations on the V1 schema only.
    query = (
        select(func.count(StoredConversationMetadata.conversation_id))
        .where(StoredConversationMetadata.public == True)  # noqa: E712
        .where(StoredConversationMetadata.conversation_version == 'V1')
    )

    query = self._apply_filters(
        query=query,
        title__contains=title__contains,
        created_at__gte=created_at__gte,
        created_at__lt=created_at__lt,
        updated_at__gte=updated_at__gte,
        updated_at__lt=updated_at__lt,
    )

    result = await self.db_session.execute(query)
    # scalar() may be None for an empty result set; normalize to 0.
    return result.scalar() or 0
|
||||
|
||||
async def get_shared_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
query = select(StoredConversationMetadata).where(
|
||||
StoredConversationMetadata.conversation_version == 'V1'
|
||||
)
|
||||
# Only include conversations marked as public
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply common filters to a query."""
|
||||
if title__contains is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.title.contains(title__contains)
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.created_at >= created_at__gte
|
||||
)
|
||||
|
||||
if created_at__lt is not None:
|
||||
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
|
||||
# Rebuild token usage
|
||||
token_usage = TokenUsage(
|
||||
prompt_tokens=stored.prompt_tokens,
|
||||
completion_tokens=stored.completion_tokens,
|
||||
cache_read_tokens=stored.cache_read_tokens,
|
||||
cache_write_tokens=stored.cache_write_tokens,
|
||||
context_window=stored.context_window,
|
||||
per_turn_token=stored.per_turn_token,
|
||||
)
|
||||
|
||||
# Rebuild metrics object
|
||||
metrics = MetricsSnapshot(
|
||||
accumulated_cost=stored.accumulated_cost,
|
||||
max_budget_per_task=stored.max_budget_per_task,
|
||||
accumulated_token_usage=token_usage,
|
||||
)
|
||||
|
||||
# Get timestamps
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
git_provider=(
|
||||
ProviderType(stored.git_provider) if stored.git_provider else None
|
||||
),
|
||||
title=stored.title,
|
||||
pr_number=stored.pr_number,
|
||||
llm_model=stored.llm_model,
|
||||
metrics=metrics,
|
||||
parent_conversation_id=(
|
||||
UUID(stored.parent_conversation_id)
|
||||
if stored.parent_conversation_id
|
||||
else None
|
||||
),
|
||||
sub_conversation_ids=sub_conversation_ids or [],
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _fix_timezone(self, value: datetime) -> datetime:
|
||||
"""Sqlite does not store timezones - and since we can't update the existing models
|
||||
we assume UTC if the timezone is missing."""
|
||||
if not value.tzinfo:
|
||||
value = value.replace(tzinfo=UTC)
|
||||
return value
|
||||
|
||||
|
||||
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedConversationInfoService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state, request) as db_session:
|
||||
service = SQLSharedConversationInfoService(db_session=db_session)
|
||||
yield service
|
||||
@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
|
||||
kwargs.pop('context_window', None)
|
||||
kwargs.pop('per_turn_token', None)
|
||||
kwargs.pop('parent_conversation_id', None)
|
||||
kwargs.pop('public')
|
||||
|
||||
return ConversationMetadata(**kwargs)
|
||||
|
||||
|
||||
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for sharing package."""
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Tests for public conversation models."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversation,
|
||||
SharedConversationPage,
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
|
||||
|
||||
def test_public_conversation_creation():
|
||||
"""Test that SharedConversation can be created with all required fields."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
assert conversation.id == conversation_id
|
||||
assert conversation.title == 'Test Conversation'
|
||||
assert conversation.created_by_user_id == 'test_user'
|
||||
assert conversation.sandbox_id == 'test_sandbox'
|
||||
|
||||
|
||||
def test_public_conversation_page_creation():
|
||||
"""Test that SharedConversationPage can be created."""
|
||||
conversation_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository=None,
|
||||
parent_conversation_id=None,
|
||||
)
|
||||
|
||||
page = SharedConversationPage(
|
||||
items=[conversation],
|
||||
next_page_id='next_page',
|
||||
)
|
||||
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conversation_id
|
||||
assert page.next_page_id == 'next_page'
|
||||
|
||||
|
||||
def test_public_conversation_sort_order_enum():
|
||||
"""Test that SharedConversationSortOrder enum has expected values."""
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
|
||||
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE')
|
||||
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
|
||||
|
||||
|
||||
def test_public_conversation_optional_fields():
|
||||
"""Test that SharedConversation works with optional fields."""
|
||||
conversation_id = uuid4()
|
||||
parent_id = uuid4()
|
||||
now = datetime.utcnow()
|
||||
|
||||
conversation = SharedConversation(
|
||||
id=conversation_id,
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Conversation',
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
selected_repository='owner/repo',
|
||||
parent_conversation_id=parent_id,
|
||||
llm_model='gpt-4',
|
||||
)
|
||||
|
||||
assert conversation.selected_repository == 'owner/repo'
|
||||
assert conversation.parent_conversation_id == parent_id
|
||||
assert conversation.llm_model == 'gpt-4'
|
||||
@@ -0,0 +1,430 @@
|
||||
"""Tests for SharedConversationInfoService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_conversation_models import (
|
||||
SharedConversationSortOrder,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def shared_conversation_info_service(async_session):
|
||||
"""Create a SharedConversationInfoService for testing."""
|
||||
return SQLSharedConversationInfoService(db_session=async_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def app_conversation_service(async_session):
|
||||
"""Create an AppConversationInfoService for creating test data."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info():
|
||||
"""Create a sample conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
selected_repository='test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=1.5,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=150,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=True, # Make it public for testing
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_private_conversation_info():
|
||||
"""Create a sample private conversation info for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_private',
|
||||
selected_repository='test/private_repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Private Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[124],
|
||||
llm_model='gpt-4',
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=2.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(
|
||||
prompt_tokens=200,
|
||||
completion_tokens=100,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=4096,
|
||||
per_turn_token=300,
|
||||
),
|
||||
),
|
||||
parent_conversation_id=None,
|
||||
sub_conversation_ids=[],
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
public=False, # Make it private
|
||||
)
|
||||
|
||||
|
||||
class TestSharedConversationInfoService:
|
||||
"""Test cases for SharedConversationInfoService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_public_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns a public conversation."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == sample_conversation_info.id
|
||||
assert result.title == sample_conversation_info.title
|
||||
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for private conversations."""
|
||||
# Create a private conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Try to retrieve it via public service
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
sample_private_conversation_info.id
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
|
||||
self, shared_conversation_info_service
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
|
||||
nonexistent_id = uuid4()
|
||||
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||
nonexistent_id
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_returns_only_public_conversations(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test that search only returns public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Search for all conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info()
|
||||
)
|
||||
|
||||
# Should only return the public conversation
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].id == sample_conversation_info.id
|
||||
assert result.items[0].title == sample_conversation_info.title
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_title_filter(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
):
|
||||
"""Test searching with title filter."""
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
|
||||
# Search with matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='Test'
|
||||
)
|
||||
assert len(result.items) == 1
|
||||
|
||||
# Search with non-matching title
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
title__contains='NonExistent'
|
||||
)
|
||||
assert len(result.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversation_info_with_sort_order(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test searching with different sort orders."""
|
||||
# Create multiple public conversations with different titles and timestamps
|
||||
conv1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_1',
|
||||
title='A First Conversation',
|
||||
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conv2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox_2',
|
||||
title='B Second Conversation',
|
||||
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
await app_conversation_service.save_app_conversation_info(conv1)
|
||||
await app_conversation_service.save_app_conversation_info(conv2)
|
||||
|
||||
# Test sort by title ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'A First Conversation'
|
||||
assert result.items[1].title == 'B Second Conversation'
|
||||
|
||||
# Test sort by title descending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.TITLE_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].title == 'B Second Conversation'
|
||||
assert result.items[1].title == 'A First Conversation'
|
||||
|
||||
# Test sort by created_at ascending
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv1.id
|
||||
assert result.items[1].id == conv2.id
|
||||
|
||||
# Test sort by created_at descending (default)
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].id == conv2.id
|
||||
assert result.items[1].id == conv1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test counting public conversations."""
|
||||
# Initially should be 0
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 0
|
||||
|
||||
# Create a public conversation
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
# Create a private conversation - count should remain 1
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_get_shared_conversation_info(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
sample_conversation_info,
|
||||
sample_private_conversation_info,
|
||||
):
|
||||
"""Test batch getting public conversations."""
|
||||
# Create both public and private conversations
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_conversation_info
|
||||
)
|
||||
await app_conversation_service.save_app_conversation_info(
|
||||
sample_private_conversation_info
|
||||
)
|
||||
|
||||
# Batch get both conversations
|
||||
result = (
|
||||
await shared_conversation_info_service.batch_get_shared_conversation_info(
|
||||
[sample_conversation_info.id, sample_private_conversation_info.id]
|
||||
)
|
||||
)
|
||||
|
||||
# Should return the public one and None for the private one
|
||||
assert len(result) == 2
|
||||
assert result[0] is not None
|
||||
assert result[0].id == sample_conversation_info.id
|
||||
assert result[1] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_pagination(
|
||||
self,
|
||||
shared_conversation_info_service,
|
||||
app_conversation_service,
|
||||
):
|
||||
"""Test search with pagination."""
|
||||
# Create multiple public conversations
|
||||
conversations = []
|
||||
for i in range(5):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id=f'test_sandbox_{i}',
|
||||
title=f'Conversation {i}',
|
||||
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conversations.append(conv)
|
||||
await app_conversation_service.save_app_conversation_info(conv)
|
||||
|
||||
# Get first page with limit 2
|
||||
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
assert len(result.items) == 2
|
||||
assert result.next_page_id is not None
|
||||
|
||||
# Get next page
|
||||
result2 = (
|
||||
await shared_conversation_info_service.search_shared_conversation_info(
|
||||
limit=2,
|
||||
page_id=result.next_page_id,
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT,
|
||||
)
|
||||
)
|
||||
assert len(result2.items) == 2
|
||||
assert result2.next_page_id is not None
|
||||
|
||||
# Verify no overlap between pages
|
||||
page1_ids = {item.id for item in result.items}
|
||||
page2_ids = {item.id for item in result2.items}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
@@ -0,0 +1,365 @@
|
||||
"""Tests for SharedEventService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
SharedEventServiceImpl,
|
||||
)
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import SharedConversation
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_shared_conversation_info_service():
|
||||
"""Create a mock SharedConversationInfoService."""
|
||||
return AsyncMock(spec=SharedConversationInfoService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_service():
|
||||
"""Create a mock EventService."""
|
||||
return AsyncMock(spec=EventService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
|
||||
"""Create a SharedEventService for testing."""
|
||||
return SharedEventServiceImpl(
|
||||
shared_conversation_info_service=mock_shared_conversation_info_service,
|
||||
event_service=mock_event_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_public_conversation():
|
||||
"""Create a sample public conversation."""
|
||||
return SharedConversation(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Public Conversation',
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event():
|
||||
"""Create a sample event."""
|
||||
# For testing purposes, we'll just use a mock that the EventPage can accept
|
||||
# The actual event creation is complex and not the focus of these tests
|
||||
return None
|
||||
|
||||
|
||||
class TestSharedEventService:
|
||||
"""Test cases for SharedEventService."""
|
||||
|
||||
async def test_get_shared_event_returns_event_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that get_shared_event returns an event for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_id = 'test_event_id'
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return an event
|
||||
mock_event_service.get_event.return_value = sample_event
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
|
||||
# Verify the result
|
||||
assert result == sample_event
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.get_event.assert_called_once_with(event_id)
|
||||
|
||||
async def test_get_shared_event_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that get_shared_event returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_id = 'test_event_id'
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||
|
||||
# Verify the result
|
||||
assert result is None
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that search_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id=None)
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
assert len(result.items) == 0 # Empty list as we mocked
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
page_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
async def test_search_shared_events_returns_empty_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that search_shared_events returns empty page for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EventPage)
|
||||
assert len(result.items) == 0
|
||||
assert result.next_page_id is None
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.search_events.assert_not_called()
|
||||
|
||||
async def test_count_shared_events_returns_count_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test that count_shared_events returns count for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return a count
|
||||
mock_event_service.count_events.return_value = 5
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 5
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.count_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
)
|
||||
|
||||
async def test_count_shared_events_returns_zero_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that count_shared_events returns 0 for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 0
|
||||
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.count_events.assert_not_called()
|
||||
|
||||
async def test_batch_get_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_ids = ['event1', 'event2']
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_service.get_event.side_effect = [sample_event, None]
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert result[0] == sample_event
|
||||
assert result[1] is None
|
||||
|
||||
# Verify that get_shared_conversation_info was called for each event
|
||||
assert (
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||
== 2
|
||||
)
|
||||
# Verify that get_event was called for each event
|
||||
assert mock_event_service.get_event.call_count == 2
|
||||
|
||||
async def test_batch_get_shared_events_returns_none_for_private_conversation(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_ids = ['event1', 'event2']
|
||||
|
||||
# Mock the public conversation service to return None (private conversation)
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(result) == 2
|
||||
assert result[0] is None
|
||||
assert result[1] is None
|
||||
|
||||
# Verify that get_shared_conversation_info was called for each event
|
||||
assert (
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||
== 2
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_with_all_parameters(
|
||||
self,
|
||||
shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test search_shared_events with all parameters."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
|
||||
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id='next_page')
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method with all parameters
|
||||
result = await shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ObservationEvent',
|
||||
timestamp__gte=timestamp_gte,
|
||||
timestamp__lt=timestamp_lt,
|
||||
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||
page_id='current_page',
|
||||
limit=50,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id__eq=conversation_id,
|
||||
kind__eq='ObservationEvent',
|
||||
timestamp__gte=timestamp_gte,
|
||||
timestamp__lt=timestamp_lt,
|
||||
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||
page_id='current_page',
|
||||
limit=50,
|
||||
)
|
||||
@@ -45,6 +45,8 @@ class AppConversationInfo(BaseModel):
|
||||
parent_conversation_id: OpenHandsUUID | None = None
|
||||
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
||||
|
||||
public: bool | None = None
|
||||
|
||||
created_at: datetime = Field(default_factory=utc_now)
|
||||
updated_at: datetime = Field(default_factory=utc_now)
|
||||
|
||||
@@ -114,6 +116,12 @@ class AppConversationStartRequest(BaseModel):
|
||||
parent_conversation_id: OpenHandsUUID | None = None
|
||||
agent_type: AgentType = Field(default=AgentType.DEFAULT)
|
||||
|
||||
public: bool | None = None
|
||||
|
||||
|
||||
class AppConversationUpdateRequest(BaseModel):
|
||||
public: bool
|
||||
|
||||
|
||||
class AppConversationStartTaskStatus(Enum):
|
||||
WORKING = 'WORKING'
|
||||
|
||||
@@ -40,6 +40,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTask,
|
||||
AppConversationStartTaskPage,
|
||||
AppConversationStartTaskSortOrder,
|
||||
AppConversationUpdateRequest,
|
||||
SkillResponse,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
@@ -222,6 +223,22 @@ async def start_app_conversation(
|
||||
raise
|
||||
|
||||
|
||||
@router.patch('/{conversation_id}')
|
||||
async def update_app_conversation(
|
||||
conversation_id: str,
|
||||
update_request: AppConversationUpdateRequest,
|
||||
app_conversation_service: AppConversationService = (
|
||||
app_conversation_service_dependency
|
||||
),
|
||||
) -> AppConversation:
|
||||
info = await app_conversation_service.update_app_conversation(
|
||||
UUID(conversation_id), update_request
|
||||
)
|
||||
if info is None:
|
||||
raise HTTPException(404, 'unknown_app_conversation')
|
||||
return info
|
||||
|
||||
|
||||
@router.post('/stream-start')
|
||||
async def stream_app_conversation_start(
|
||||
request: AppConversationStartRequest,
|
||||
|
||||
@@ -10,6 +10,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationSortOrder,
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTask,
|
||||
AppConversationUpdateRequest,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.services.injector import Injector
|
||||
@@ -98,6 +99,13 @@ class AppConversationService(ABC):
|
||||
"""Run the setup scripts for the project and yield status updates"""
|
||||
yield task
|
||||
|
||||
@abstractmethod
|
||||
async def update_app_conversation(
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
"""Delete a V1 conversation and all its associated data.
|
||||
|
||||
@@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTask,
|
||||
AppConversationStartTaskStatus,
|
||||
AppConversationUpdateRequest,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
AppConversationService,
|
||||
@@ -1049,6 +1050,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
f'Successfully updated agent-server conversation {conversation_id} title to "{new_title}"'
|
||||
)
|
||||
|
||||
async def update_app_conversation(
|
||||
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||
) -> AppConversation | None:
|
||||
"""Update an app conversation and return it. Return None if the conversation
|
||||
did not exist."""
|
||||
info = await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
for field_name in request.model_fields:
|
||||
value = getattr(request, field_name)
|
||||
setattr(info, field_name, value)
|
||||
info = await self.app_conversation_info_service.save_app_conversation_info(info)
|
||||
conversations = await self._build_app_conversations([info])
|
||||
return conversations[0]
|
||||
|
||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||
"""Delete a V1 conversation and all its associated data.
|
||||
|
||||
|
||||
@@ -25,7 +25,17 @@ from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, Select, String, func, select
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
Select,
|
||||
String,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.agent_server.utils import utc_now
|
||||
@@ -91,6 +101,7 @@ class StoredConversationMetadata(Base): # type: ignore
|
||||
conversation_version = Column(String, nullable=False, default='V0', index=True)
|
||||
sandbox_id = Column(String, nullable=True, index=True)
|
||||
parent_conversation_id = Column(String, nullable=True, index=True)
|
||||
public = Column(Boolean, nullable=True, index=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -350,6 +361,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
if info.parent_conversation_id
|
||||
else None
|
||||
),
|
||||
public=info.public,
|
||||
)
|
||||
|
||||
await self.db_session.merge(stored)
|
||||
@@ -541,6 +553,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
||||
else None
|
||||
),
|
||||
sub_conversation_ids=sub_conversation_ids or [],
|
||||
public=stored.public,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
41
openhands/app_server/app_lifespan/alembic/versions/004.py
Normal file
41
openhands/app_server/app_lifespan/alembic/versions/004.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""add public column to conversation_metadata
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2025-01-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '004'
|
||||
down_revision: Union[str, None] = '003'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('public', sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
'conversation_metadata',
|
||||
['public'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_index(
|
||||
op.f('ix_conversation_metadata_public'),
|
||||
table_name='conversation_metadata',
|
||||
)
|
||||
op.drop_column('conversation_metadata', 'public')
|
||||
@@ -22,7 +22,7 @@ event_service_dependency = depends_event_service()
|
||||
@router.get('/search')
|
||||
async def search_events(
|
||||
conversation_id__eq: Annotated[
|
||||
UUID | None,
|
||||
str | None,
|
||||
Query(title='Optional filter by conversation ID'),
|
||||
] = None,
|
||||
kind__eq: Annotated[
|
||||
@@ -55,7 +55,7 @@ async def search_events(
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await event_service.search_events(
|
||||
conversation_id__eq=conversation_id__eq,
|
||||
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
@@ -68,7 +68,7 @@ async def search_events(
|
||||
@router.get('/count')
|
||||
async def count_events(
|
||||
conversation_id__eq: Annotated[
|
||||
UUID | None,
|
||||
str | None,
|
||||
Query(title='Optional filter by conversation ID'),
|
||||
] = None,
|
||||
kind__eq: Annotated[
|
||||
@@ -91,7 +91,7 @@ async def count_events(
|
||||
) -> int:
|
||||
"""Count events matching the given filters."""
|
||||
return await event_service.count_events(
|
||||
conversation_id__eq=conversation_id__eq,
|
||||
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
|
||||
@@ -1,32 +1,27 @@
|
||||
"""Filesystem-based EventService implementation."""
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.errors import OpenHandsError
|
||||
from openhands.app_server.event.event_service import EventService, EventServiceInjector
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.event.filesystem_event_service_base import (
|
||||
FilesystemEventServiceBase,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilesystemEventService(EventService):
|
||||
class FilesystemEventService(FilesystemEventServiceBase, EventService):
|
||||
"""Filesystem-based implementation of EventService.
|
||||
|
||||
Events are stored in files with the naming format:
|
||||
@@ -47,25 +42,6 @@ class FilesystemEventService(EventService):
|
||||
events_path.mkdir(parents=True, exist_ok=True)
|
||||
return events_path
|
||||
|
||||
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
|
||||
"""Convert timestamp to YYYYMMDDHHMMSS format."""
|
||||
if isinstance(timestamp, str):
|
||||
# Parse ISO format timestamp string
|
||||
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
||||
return dt.strftime('%Y%m%d%H%M%S')
|
||||
return timestamp.strftime('%Y%m%d%H%M%S')
|
||||
|
||||
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
|
||||
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
|
||||
timestamp_str = self._timestamp_to_str(event.timestamp)
|
||||
kind = event.__class__.__name__
|
||||
# Handle both UUID objects and string UUIDs
|
||||
if isinstance(event.id, str):
|
||||
id_hex = event.id.replace('-', '')
|
||||
else:
|
||||
id_hex = event.id.hex
|
||||
return f'{timestamp_str}_{kind}_{id_hex}'
|
||||
|
||||
def _save_event_to_file(self, conversation_id: UUID, event: Event) -> None:
|
||||
"""Save an event to a file."""
|
||||
events_path = self._ensure_events_dir(conversation_id)
|
||||
@@ -77,60 +53,17 @@ class FilesystemEventService(EventService):
|
||||
data = event.model_dump(mode='json')
|
||||
f.write(json.dumps(data, indent=2))
|
||||
|
||||
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
|
||||
events = []
|
||||
for file_path in file_paths:
|
||||
event = self._load_event_from_file(file_path)
|
||||
if event is not None:
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
def _load_event_from_file(self, filepath: Path) -> Event | None:
|
||||
"""Load an event from a file."""
|
||||
try:
|
||||
json_data = filepath.read_text()
|
||||
return Event.model_validate_json(json_data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _get_event_files_by_pattern(
|
||||
self, pattern: str, conversation_id: UUID | None = None
|
||||
) -> list[Path]:
|
||||
"""Get event files matching a glob pattern, sorted by timestamp."""
|
||||
if conversation_id:
|
||||
search_path = self.events_dir / str(conversation_id) / pattern
|
||||
else:
|
||||
search_path = self.events_dir / '*' / pattern
|
||||
|
||||
files = glob.glob(str(search_path))
|
||||
return sorted([Path(f) for f in files])
|
||||
|
||||
def _parse_filename(self, filename: str) -> dict[str, str] | None:
|
||||
"""Parse filename to extract timestamp, kind, and event_id."""
|
||||
try:
|
||||
parts = filename.split('_')
|
||||
if len(parts) >= 3:
|
||||
timestamp_str = parts[0]
|
||||
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
|
||||
event_id = parts[-1]
|
||||
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _get_conversation_id(self, file: Path) -> UUID | None:
|
||||
try:
|
||||
return UUID(file.parent.name)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
|
||||
result = set()
|
||||
for file in files:
|
||||
conversation_id = self._get_conversation_id(file)
|
||||
if conversation_id:
|
||||
result.add(conversation_id)
|
||||
return result
|
||||
async def save_event(self, conversation_id: UUID, event: Event):
|
||||
"""Save an event. Internal method intended not be part of the REST api."""
|
||||
conversation = (
|
||||
await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if not conversation:
|
||||
# This is either an illegal state or somebody is trying to hack
|
||||
raise OpenHandsError('No such conversation: {conversaiont_id}')
|
||||
self._save_event_to_file(conversation_id, event)
|
||||
|
||||
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
|
||||
conversation_ids = list(self._get_conversation_ids(files))
|
||||
@@ -150,161 +83,6 @@ class FilesystemEventService(EventService):
|
||||
]
|
||||
return result
|
||||
|
||||
def _filter_files_by_criteria(
|
||||
self,
|
||||
files: list[Path],
|
||||
conversation_id__eq: UUID | None = None,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
) -> list[Path]:
|
||||
"""Filter files based on search criteria."""
|
||||
filtered_files = []
|
||||
|
||||
for file_path in files:
|
||||
# Check conversation_id filter
|
||||
if conversation_id__eq:
|
||||
if str(conversation_id__eq) not in str(file_path):
|
||||
continue
|
||||
|
||||
# Parse filename for additional filtering
|
||||
filename_info = self._parse_filename(file_path.name)
|
||||
if not filename_info:
|
||||
continue
|
||||
|
||||
# Check kind filter
|
||||
if kind__eq and filename_info['kind'] != kind__eq:
|
||||
continue
|
||||
|
||||
# Check timestamp filters
|
||||
if timestamp__gte or timestamp__lt:
|
||||
try:
|
||||
file_timestamp = datetime.strptime(
|
||||
filename_info['timestamp'], '%Y%m%d%H%M%S'
|
||||
)
|
||||
if timestamp__gte and file_timestamp < timestamp__gte:
|
||||
continue
|
||||
if timestamp__lt and file_timestamp >= timestamp__lt:
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
filtered_files.append(file_path)
|
||||
|
||||
return filtered_files
|
||||
|
||||
async def get_event(self, event_id: str) -> Event | None:
|
||||
"""Get the event with the given id, or None if not found."""
|
||||
# Convert event_id to hex format (remove dashes) for filename matching
|
||||
if isinstance(event_id, str) and '-' in event_id:
|
||||
id_hex = event_id.replace('-', '')
|
||||
else:
|
||||
id_hex = event_id
|
||||
|
||||
# Use glob pattern to find files ending with the event_id
|
||||
pattern = f'*_{id_hex}'
|
||||
files = self._get_event_files_by_pattern(pattern)
|
||||
|
||||
if not files:
|
||||
return None
|
||||
|
||||
# If there is no access to the conversation do not return the event
|
||||
file = files[0]
|
||||
conversation_id = self._get_conversation_id(file)
|
||||
if not conversation_id:
|
||||
return None
|
||||
conversation = (
|
||||
await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
# Load and return the first matching event
|
||||
return self._load_event_from_file(file)
|
||||
|
||||
async def search_events(
|
||||
self,
|
||||
conversation_id__eq: UUID | None = None,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search for events matching the given filters."""
|
||||
# Build the search pattern
|
||||
pattern = '*'
|
||||
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
||||
|
||||
files = await self._filter_files_by_conversation(files)
|
||||
|
||||
files = self._filter_files_by_criteria(
|
||||
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
||||
)
|
||||
|
||||
files.sort(
|
||||
key=lambda f: f.name,
|
||||
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
|
||||
)
|
||||
|
||||
# Handle pagination
|
||||
start_index = 0
|
||||
if page_id:
|
||||
for i, file_path in enumerate(files):
|
||||
if file_path.name == page_id:
|
||||
start_index = i + 1
|
||||
break
|
||||
|
||||
# Collect items for this page
|
||||
page_files = files[start_index : start_index + limit]
|
||||
next_page_id = None
|
||||
if start_index + limit < len(files):
|
||||
next_page_id = files[start_index + limit].name
|
||||
|
||||
# Load all events from files in a background thread.
|
||||
loop = asyncio.get_running_loop()
|
||||
page_events = await loop.run_in_executor(
|
||||
None, self._load_events_from_files, page_files
|
||||
)
|
||||
|
||||
return EventPage(items=page_events, next_page_id=next_page_id)
|
||||
|
||||
async def count_events(
|
||||
self,
|
||||
conversation_id__eq: UUID | None = None,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events matching the given filters."""
|
||||
# Build the search pattern
|
||||
pattern = '*'
|
||||
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
||||
|
||||
files = await self._filter_files_by_conversation(files)
|
||||
|
||||
files = self._filter_files_by_criteria(
|
||||
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
||||
)
|
||||
|
||||
return len(files)
|
||||
|
||||
async def save_event(self, conversation_id: UUID, event: Event):
|
||||
"""Save an event. Internal method intended not be part of the REST api."""
|
||||
conversation = (
|
||||
await self.app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if not conversation:
|
||||
# This is either an illegal state or somebody is trying to hack
|
||||
raise OpenHandsError('No such conversation: {conversaiont_id}')
|
||||
self._save_event_to_file(conversation_id, event)
|
||||
|
||||
|
||||
class FilesystemEventServiceInjector(EventServiceInjector):
|
||||
async def inject(
|
||||
|
||||
224
openhands/app_server/event/filesystem_event_service_base.py
Normal file
224
openhands/app_server/event/filesystem_event_service_base.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import asyncio
|
||||
import glob
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
|
||||
|
||||
class FilesystemEventServiceBase:
|
||||
events_dir: Path
|
||||
|
||||
async def get_event(self, event_id: str) -> Event | None:
|
||||
"""Get the event with the given id, or None if not found."""
|
||||
# Convert event_id to hex format (remove dashes) for filename matching
|
||||
if isinstance(event_id, str) and '-' in event_id:
|
||||
id_hex = event_id.replace('-', '')
|
||||
else:
|
||||
id_hex = event_id
|
||||
|
||||
# Use glob pattern to find files ending with the event_id
|
||||
pattern = f'*_{id_hex}'
|
||||
files = self._get_event_files_by_pattern(pattern)
|
||||
|
||||
files = await self._filter_files_by_conversation(files)
|
||||
|
||||
if not files:
|
||||
return None
|
||||
|
||||
# Load and return the first matching event
|
||||
return self._load_event_from_file(files[0])
|
||||
|
||||
async def search_events(
|
||||
self,
|
||||
conversation_id__eq: UUID | None = None,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search for events matching the given filters."""
|
||||
# Build the search pattern
|
||||
pattern = '*'
|
||||
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
||||
|
||||
files = await self._filter_files_by_conversation(files)
|
||||
|
||||
files = self._filter_files_by_criteria(
|
||||
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
||||
)
|
||||
|
||||
files.sort(
|
||||
key=lambda f: f.name,
|
||||
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
|
||||
)
|
||||
|
||||
# Handle pagination
|
||||
start_index = 0
|
||||
if page_id:
|
||||
for i, file_path in enumerate(files):
|
||||
if file_path.name == page_id:
|
||||
start_index = i + 1
|
||||
break
|
||||
|
||||
# Collect items for this page
|
||||
page_files = files[start_index : start_index + limit]
|
||||
next_page_id = None
|
||||
if start_index + limit < len(files):
|
||||
next_page_id = files[start_index + limit].name
|
||||
|
||||
# Load all events from files in a background thread.
|
||||
loop = asyncio.get_running_loop()
|
||||
page_events = await loop.run_in_executor(
|
||||
None, self._load_events_from_files, page_files
|
||||
)
|
||||
|
||||
return EventPage(items=page_events, next_page_id=next_page_id)
|
||||
|
||||
async def count_events(
|
||||
self,
|
||||
conversation_id__eq: UUID | None = None,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
) -> int:
|
||||
"""Count events matching the given filters."""
|
||||
# Build the search pattern
|
||||
pattern = '*'
|
||||
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
||||
|
||||
files = await self._filter_files_by_conversation(files)
|
||||
|
||||
files = self._filter_files_by_criteria(
|
||||
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
||||
)
|
||||
|
||||
return len(files)
|
||||
|
||||
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
|
||||
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
|
||||
timestamp_str = self._timestamp_to_str(event.timestamp)
|
||||
kind = event.__class__.__name__
|
||||
# Handle both UUID objects and string UUIDs
|
||||
if isinstance(event.id, str):
|
||||
id_hex = event.id.replace('-', '')
|
||||
else:
|
||||
id_hex = event.id.hex
|
||||
return f'{timestamp_str}_{kind}_{id_hex}'
|
||||
|
||||
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
|
||||
"""Convert timestamp to YYYYMMDDHHMMSS format."""
|
||||
if isinstance(timestamp, str):
|
||||
# Parse ISO format timestamp string
|
||||
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
||||
return dt.strftime('%Y%m%d%H%M%S')
|
||||
return timestamp.strftime('%Y%m%d%H%M%S')
|
||||
|
||||
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
|
||||
events = []
|
||||
for file_path in file_paths:
|
||||
event = self._load_event_from_file(file_path)
|
||||
if event is not None:
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
def _load_event_from_file(self, filepath: Path) -> Event | None:
|
||||
"""Load an event from a file."""
|
||||
try:
|
||||
json_data = filepath.read_text()
|
||||
return Event.model_validate_json(json_data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _get_event_files_by_pattern(
|
||||
self, pattern: str, conversation_id: UUID | None = None
|
||||
) -> list[Path]:
|
||||
"""Get event files matching a glob pattern, sorted by timestamp."""
|
||||
if conversation_id:
|
||||
search_path = self.events_dir / str(conversation_id) / pattern
|
||||
else:
|
||||
search_path = self.events_dir / '*' / pattern
|
||||
|
||||
files = glob.glob(str(search_path))
|
||||
return sorted([Path(f) for f in files])
|
||||
|
||||
def _parse_filename(self, filename: str) -> dict[str, str] | None:
|
||||
"""Parse filename to extract timestamp, kind, and event_id."""
|
||||
try:
|
||||
parts = filename.split('_')
|
||||
if len(parts) >= 3:
|
||||
timestamp_str = parts[0]
|
||||
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
|
||||
event_id = parts[-1]
|
||||
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _get_conversation_id(self, file: Path) -> UUID | None:
|
||||
try:
|
||||
return UUID(file.parent.name)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
|
||||
result = set()
|
||||
for file in files:
|
||||
conversation_id = self._get_conversation_id(file)
|
||||
if conversation_id:
|
||||
result.add(conversation_id)
|
||||
return result
|
||||
|
||||
@abstractmethod
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
    """Filter files by conversation.

    Abstract hook: concrete subclasses decide which of *files* the caller
    is allowed to see and return only those entries.
    NOTE(review): the exact visibility rules (ownership, sharing, ordering)
    are defined by the subclasses — confirm there before relying on them.
    """
|
||||
|
||||
def _filter_files_by_criteria(
|
||||
self,
|
||||
files: list[Path],
|
||||
conversation_id__eq: UUID | None = None,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
) -> list[Path]:
|
||||
"""Filter files based on search criteria."""
|
||||
filtered_files = []
|
||||
|
||||
for file_path in files:
|
||||
# Check conversation_id filter
|
||||
if conversation_id__eq:
|
||||
if str(conversation_id__eq) not in str(file_path):
|
||||
continue
|
||||
|
||||
# Parse filename for additional filtering
|
||||
filename_info = self._parse_filename(file_path.name)
|
||||
if not filename_info:
|
||||
continue
|
||||
|
||||
# Check kind filter
|
||||
if kind__eq and filename_info['kind'] != kind__eq:
|
||||
continue
|
||||
|
||||
# Check timestamp filters
|
||||
if timestamp__gte or timestamp__lt:
|
||||
try:
|
||||
file_timestamp = datetime.strptime(
|
||||
filename_info['timestamp'], '%Y%m%d%H%M%S'
|
||||
)
|
||||
if timestamp__gte and file_timestamp < timestamp__gte:
|
||||
continue
|
||||
if timestamp__lt and file_timestamp >= timestamp__lt:
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
filtered_files.append(file_path)
|
||||
|
||||
return filtered_files
|
||||
@@ -29,3 +29,4 @@ class ConversationInfo:
|
||||
pr_number: list[int] = field(default_factory=list)
|
||||
conversation_version: str = 'V0'
|
||||
sub_conversation_ids: list[str] = field(default_factory=list)
|
||||
public: bool | None = None
|
||||
|
||||
@@ -1501,4 +1501,5 @@ def _to_conversation_info(app_conversation: AppConversation) -> ConversationInfo
|
||||
sub_conversation_ids=[
|
||||
sub_id.hex for sub_id in app_conversation.sub_conversation_ids
|
||||
],
|
||||
public=app_conversation.public,
|
||||
)
|
||||
|
||||
@@ -39,3 +39,4 @@ class ConversationMetadata:
|
||||
# V1 compatibility
|
||||
sandbox_id: str | None = None
|
||||
conversation_version: str | None = None
|
||||
public: bool | None = None
|
||||
|
||||
15
poetry.lock
generated
15
poetry.lock
generated
@@ -12707,18 +12707,19 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "1.1.0"
|
||||
version = "1.3.0"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["test"]
|
||||
python-versions = ">=3.10"
|
||||
groups = ["dev", "test"]
|
||||
files = [
|
||||
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
|
||||
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
|
||||
{file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"},
|
||||
{file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=8.2,<9"
|
||||
pytest = ">=8.2,<10"
|
||||
typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
||||
@@ -16823,4 +16824,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "9764f3b69ec8ed35feebd78a826bbc6bfa4ac6d5b56bc999be8bc738b644e538"
|
||||
content-hash = "e24ceb52bccd0c80f52c408215ccf007475eb69e10b895053ea49c7e3e4be3b8"
|
||||
|
||||
@@ -139,6 +139,7 @@ pre-commit = "4.2.0"
|
||||
build = "*"
|
||||
types-setuptools = "*"
|
||||
pytest = "^8.4.0"
|
||||
pytest-asyncio = "^1.3.0"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
Reference in New Issue
Block a user