mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 23:08:04 -05:00
ALL-4634: implement public conversation sharing feature (#12044)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -0,0 +1,41 @@
|
|||||||
|
"""add public column to conversation_metadata
|
||||||
|
|
||||||
|
Revision ID: 085
|
||||||
|
Revises: 084
|
||||||
|
Create Date: 2025-01-27 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '085'
|
||||||
|
down_revision: Union[str, None] = '084'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
op.add_column(
|
||||||
|
'conversation_metadata',
|
||||||
|
sa.Column('public', sa.Boolean(), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f('ix_conversation_metadata_public'),
|
||||||
|
'conversation_metadata',
|
||||||
|
['public'],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
op.drop_index(
|
||||||
|
op.f('ix_conversation_metadata_public'),
|
||||||
|
table_name='conversation_metadata',
|
||||||
|
)
|
||||||
|
op.drop_column('conversation_metadata', 'public')
|
||||||
2
enterprise/poetry.lock
generated
2
enterprise/poetry.lock
generated
@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openhands-ai"
|
name = "openhands-ai"
|
||||||
version = "0.0.0-post.5687+7853b41ad"
|
version = "0.0.0-post.5750+f19fb1043"
|
||||||
description = "OpenHands: Code Less, Make More"
|
description = "OpenHands: Code Less, Make More"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "^3.12,<3.14"
|
python-versions = "^3.12,<3.14"
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
|||||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||||
from server.routes.readiness import readiness_router # noqa: E402
|
from server.routes.readiness import readiness_router # noqa: E402
|
||||||
from server.routes.user import saas_user_router # noqa: E402
|
from server.routes.user import saas_user_router # noqa: E402
|
||||||
|
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||||
|
router as shared_conversation_router,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_event_router import ( # noqa: E402
|
||||||
|
router as shared_event_router,
|
||||||
|
)
|
||||||
|
|
||||||
from openhands.server.app import app as base_app # noqa: E402
|
from openhands.server.app import app as base_app # noqa: E402
|
||||||
from openhands.server.listen_socket import sio # noqa: E402
|
from openhands.server.listen_socket import sio # noqa: E402
|
||||||
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
|
|||||||
base_app.include_router(
|
base_app.include_router(
|
||||||
billing_router
|
billing_router
|
||||||
) # Add routes for credit management and Stripe payment integration
|
) # Add routes for credit management and Stripe payment integration
|
||||||
|
base_app.include_router(shared_conversation_router)
|
||||||
|
base_app.include_router(shared_event_router)
|
||||||
|
|
||||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||||
if GITHUB_APP_CLIENT_ID:
|
if GITHUB_APP_CLIENT_ID:
|
||||||
@@ -99,6 +107,7 @@ base_app.include_router(
|
|||||||
event_webhook_router
|
event_webhook_router
|
||||||
) # Add routes for Events in nested runtimes
|
) # Add routes for Events in nested runtimes
|
||||||
|
|
||||||
|
|
||||||
base_app.add_middleware(
|
base_app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=PERMITTED_CORS_ORIGINS,
|
allow_origins=PERMITTED_CORS_ORIGINS,
|
||||||
|
|||||||
20
enterprise/server/sharing/README.md
Normal file
20
enterprise/server/sharing/README.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Sharing Package
|
||||||
|
|
||||||
|
This package contains functionality for sharing conversations.
|
||||||
|
|
||||||
|
## Components
|
||||||
|
|
||||||
|
- **shared.py**: Data models for shared conversations
|
||||||
|
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
|
||||||
|
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
|
||||||
|
- **shared_event_service.py**: Service interface for accessing shared events
|
||||||
|
- **shared_event_service_impl.py**: Implementation of the shared event service
|
||||||
|
- **shared_conversation_router.py**: REST API endpoints for shared conversations
|
||||||
|
- **shared_event_router.py**: REST API endpoints for shared events
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Read-only access to shared conversations
|
||||||
|
- Event access for shared conversations
|
||||||
|
- Search and filtering capabilities
|
||||||
|
- Pagination support
|
||||||
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
142
enterprise/server/sharing/filesystem_shared_event_service.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Implementation of SharedEventService.
|
||||||
|
|
||||||
|
This implementation provides read-only access to events from shared conversations:
|
||||||
|
- Validates that the conversation is shared before returning events
|
||||||
|
- Uses existing EventService for actual event retrieval
|
||||||
|
- Uses SharedConversationInfoService for shared conversation validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from server.sharing.shared_conversation_info_service import (
|
||||||
|
SharedConversationInfoService,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_event_service import (
|
||||||
|
SharedEventService,
|
||||||
|
SharedEventServiceInjector,
|
||||||
|
)
|
||||||
|
from server.sharing.sql_shared_conversation_info_service import (
|
||||||
|
SQLSharedConversationInfoService,
|
||||||
|
)
|
||||||
|
|
||||||
|
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||||
|
from openhands.app_server.event.event_service import EventService
|
||||||
|
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||||
|
from openhands.app_server.services.injector import InjectorState
|
||||||
|
from openhands.sdk import Event
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SharedEventServiceImpl(SharedEventService):
|
||||||
|
"""Implementation of SharedEventService that validates shared access."""
|
||||||
|
|
||||||
|
shared_conversation_info_service: SharedConversationInfoService
|
||||||
|
event_service: EventService
|
||||||
|
|
||||||
|
async def get_shared_event(
|
||||||
|
self, conversation_id: UUID, event_id: str
|
||||||
|
) -> Event | None:
|
||||||
|
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||||
|
# First check if the conversation is shared
|
||||||
|
shared_conversation_info = (
|
||||||
|
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if shared_conversation_info is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If conversation is shared, get the event
|
||||||
|
return await self.event_service.get_event(event_id)
|
||||||
|
|
||||||
|
async def search_shared_events(
|
||||||
|
self,
|
||||||
|
conversation_id: UUID,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||||
|
page_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> EventPage:
|
||||||
|
"""Search events for a specific shared conversation."""
|
||||||
|
# First check if the conversation is shared
|
||||||
|
shared_conversation_info = (
|
||||||
|
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if shared_conversation_info is None:
|
||||||
|
# Return empty page if conversation is not shared
|
||||||
|
return EventPage(items=[], next_page_id=None)
|
||||||
|
|
||||||
|
# If conversation is shared, search events for this conversation
|
||||||
|
return await self.event_service.search_events(
|
||||||
|
conversation_id__eq=conversation_id,
|
||||||
|
kind__eq=kind__eq,
|
||||||
|
timestamp__gte=timestamp__gte,
|
||||||
|
timestamp__lt=timestamp__lt,
|
||||||
|
sort_order=sort_order,
|
||||||
|
page_id=page_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def count_shared_events(
|
||||||
|
self,
|
||||||
|
conversation_id: UUID,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||||
|
) -> int:
|
||||||
|
"""Count events for a specific shared conversation."""
|
||||||
|
# First check if the conversation is shared
|
||||||
|
shared_conversation_info = (
|
||||||
|
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if shared_conversation_info is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# If conversation is shared, count events for this conversation
|
||||||
|
return await self.event_service.count_events(
|
||||||
|
conversation_id__eq=conversation_id,
|
||||||
|
kind__eq=kind__eq,
|
||||||
|
timestamp__gte=timestamp__gte,
|
||||||
|
timestamp__lt=timestamp__lt,
|
||||||
|
sort_order=sort_order,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedEventServiceImplInjector(SharedEventServiceInjector):
|
||||||
|
async def inject(
|
||||||
|
self, state: InjectorState, request: Request | None = None
|
||||||
|
) -> AsyncGenerator[SharedEventService, None]:
|
||||||
|
# Define inline to prevent circular lookup
|
||||||
|
from openhands.app_server.config import (
|
||||||
|
get_db_session,
|
||||||
|
get_event_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
get_db_session(state, request) as db_session,
|
||||||
|
get_event_service(state, request) as event_service,
|
||||||
|
):
|
||||||
|
shared_conversation_info_service = SQLSharedConversationInfoService(
|
||||||
|
db_session=db_session
|
||||||
|
)
|
||||||
|
service = SharedEventServiceImpl(
|
||||||
|
shared_conversation_info_service=shared_conversation_info_service,
|
||||||
|
event_service=event_service,
|
||||||
|
)
|
||||||
|
yield service
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
import asyncio
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from server.sharing.shared_conversation_models import (
|
||||||
|
SharedConversation,
|
||||||
|
SharedConversationPage,
|
||||||
|
SharedConversationSortOrder,
|
||||||
|
)
|
||||||
|
|
||||||
|
from openhands.app_server.services.injector import Injector
|
||||||
|
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||||
|
|
||||||
|
|
||||||
|
class SharedConversationInfoService(ABC):
|
||||||
|
"""Service for accessing shared conversation info without user restrictions."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def search_shared_conversation_info(
|
||||||
|
self,
|
||||||
|
title__contains: str | None = None,
|
||||||
|
created_at__gte: datetime | None = None,
|
||||||
|
created_at__lt: datetime | None = None,
|
||||||
|
updated_at__gte: datetime | None = None,
|
||||||
|
updated_at__lt: datetime | None = None,
|
||||||
|
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||||
|
page_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
include_sub_conversations: bool = False,
|
||||||
|
) -> SharedConversationPage:
|
||||||
|
"""Search for shared conversations."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def count_shared_conversation_info(
|
||||||
|
self,
|
||||||
|
title__contains: str | None = None,
|
||||||
|
created_at__gte: datetime | None = None,
|
||||||
|
created_at__lt: datetime | None = None,
|
||||||
|
updated_at__gte: datetime | None = None,
|
||||||
|
updated_at__lt: datetime | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count shared conversations."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_shared_conversation_info(
|
||||||
|
self, conversation_id: UUID
|
||||||
|
) -> SharedConversation | None:
|
||||||
|
"""Get a single shared conversation info, returning None if missing or not shared."""
|
||||||
|
|
||||||
|
async def batch_get_shared_conversation_info(
|
||||||
|
self, conversation_ids: list[UUID]
|
||||||
|
) -> list[SharedConversation | None]:
|
||||||
|
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
|
||||||
|
return await asyncio.gather(
|
||||||
|
*[
|
||||||
|
self.get_shared_conversation_info(conversation_id)
|
||||||
|
for conversation_id in conversation_ids
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedConversationInfoServiceInjector(
|
||||||
|
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
|
||||||
|
):
|
||||||
|
pass
|
||||||
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
56
enterprise/server/sharing/shared_conversation_models.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
# Simplified imports to avoid dependency chain issues
|
||||||
|
# from openhands.integrations.service_types import ProviderType
|
||||||
|
# from openhands.sdk.llm import MetricsSnapshot
|
||||||
|
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||||
|
# For now, use Any to avoid import issues
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from openhands.agent_server.utils import OpenHandsUUID, utc_now
|
||||||
|
|
||||||
|
ProviderType = Any
|
||||||
|
MetricsSnapshot = Any
|
||||||
|
ConversationTrigger = Any
|
||||||
|
|
||||||
|
|
||||||
|
class SharedConversation(BaseModel):
|
||||||
|
"""Shared conversation info model with all fields from AppConversationInfo."""
|
||||||
|
|
||||||
|
id: OpenHandsUUID = Field(default_factory=uuid4)
|
||||||
|
|
||||||
|
created_by_user_id: str | None
|
||||||
|
sandbox_id: str
|
||||||
|
|
||||||
|
selected_repository: str | None = None
|
||||||
|
selected_branch: str | None = None
|
||||||
|
git_provider: ProviderType | None = None
|
||||||
|
title: str | None = None
|
||||||
|
pr_number: list[int] = Field(default_factory=list)
|
||||||
|
llm_model: str | None = None
|
||||||
|
|
||||||
|
metrics: MetricsSnapshot | None = None
|
||||||
|
|
||||||
|
parent_conversation_id: OpenHandsUUID | None = None
|
||||||
|
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
||||||
|
|
||||||
|
created_at: datetime = Field(default_factory=utc_now)
|
||||||
|
updated_at: datetime = Field(default_factory=utc_now)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedConversationSortOrder(Enum):
|
||||||
|
CREATED_AT = 'CREATED_AT'
|
||||||
|
CREATED_AT_DESC = 'CREATED_AT_DESC'
|
||||||
|
UPDATED_AT = 'UPDATED_AT'
|
||||||
|
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
|
||||||
|
TITLE = 'TITLE'
|
||||||
|
TITLE_DESC = 'TITLE_DESC'
|
||||||
|
|
||||||
|
|
||||||
|
class SharedConversationPage(BaseModel):
|
||||||
|
items: list[SharedConversation]
|
||||||
|
next_page_id: str | None = None
|
||||||
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
135
enterprise/server/sharing/shared_conversation_router.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Shared Conversation router for OpenHands Server."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from server.sharing.shared_conversation_info_service import (
|
||||||
|
SharedConversationInfoService,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_conversation_models import (
|
||||||
|
SharedConversation,
|
||||||
|
SharedConversationPage,
|
||||||
|
SharedConversationSortOrder,
|
||||||
|
)
|
||||||
|
from server.sharing.sql_shared_conversation_info_service import (
|
||||||
|
SQLSharedConversationInfoServiceInjector,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
|
||||||
|
shared_conversation_info_service_dependency = Depends(
|
||||||
|
SQLSharedConversationInfoServiceInjector().depends
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read methods
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/search')
|
||||||
|
async def search_shared_conversations(
|
||||||
|
title__contains: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(title='Filter by title containing this string'),
|
||||||
|
] = None,
|
||||||
|
created_at__gte: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||||
|
] = None,
|
||||||
|
created_at__lt: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by created_at less than this datetime'),
|
||||||
|
] = None,
|
||||||
|
updated_at__gte: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||||
|
] = None,
|
||||||
|
updated_at__lt: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by updated_at less than this datetime'),
|
||||||
|
] = None,
|
||||||
|
sort_order: Annotated[
|
||||||
|
SharedConversationSortOrder,
|
||||||
|
Query(title='Sort order for results'),
|
||||||
|
] = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||||
|
page_id: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(title='Optional next_page_id from the previously returned page'),
|
||||||
|
] = None,
|
||||||
|
limit: Annotated[
|
||||||
|
int,
|
||||||
|
Query(
|
||||||
|
title='The max number of results in the page',
|
||||||
|
gt=0,
|
||||||
|
lte=100,
|
||||||
|
),
|
||||||
|
] = 100,
|
||||||
|
include_sub_conversations: Annotated[
|
||||||
|
bool,
|
||||||
|
Query(
|
||||||
|
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||||
|
) -> SharedConversationPage:
|
||||||
|
"""Search / List shared conversations."""
|
||||||
|
assert limit > 0
|
||||||
|
assert limit <= 100
|
||||||
|
return await shared_conversation_service.search_shared_conversation_info(
|
||||||
|
title__contains=title__contains,
|
||||||
|
created_at__gte=created_at__gte,
|
||||||
|
created_at__lt=created_at__lt,
|
||||||
|
updated_at__gte=updated_at__gte,
|
||||||
|
updated_at__lt=updated_at__lt,
|
||||||
|
sort_order=sort_order,
|
||||||
|
page_id=page_id,
|
||||||
|
limit=limit,
|
||||||
|
include_sub_conversations=include_sub_conversations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/count')
|
||||||
|
async def count_shared_conversations(
|
||||||
|
title__contains: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(title='Filter by title containing this string'),
|
||||||
|
] = None,
|
||||||
|
created_at__gte: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by created_at greater than or equal to this datetime'),
|
||||||
|
] = None,
|
||||||
|
created_at__lt: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by created_at less than this datetime'),
|
||||||
|
] = None,
|
||||||
|
updated_at__gte: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by updated_at greater than or equal to this datetime'),
|
||||||
|
] = None,
|
||||||
|
updated_at__lt: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Filter by updated_at less than this datetime'),
|
||||||
|
] = None,
|
||||||
|
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||||
|
) -> int:
|
||||||
|
"""Count shared conversations matching the given filters."""
|
||||||
|
return await shared_conversation_service.count_shared_conversation_info(
|
||||||
|
title__contains=title__contains,
|
||||||
|
created_at__gte=created_at__gte,
|
||||||
|
created_at__lt=created_at__lt,
|
||||||
|
updated_at__gte=updated_at__gte,
|
||||||
|
updated_at__lt=updated_at__lt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('')
|
||||||
|
async def batch_get_shared_conversations(
|
||||||
|
ids: Annotated[list[str], Query()],
|
||||||
|
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||||
|
) -> list[SharedConversation | None]:
|
||||||
|
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||||
|
assert len(ids) <= 100
|
||||||
|
uuids = [UUID(id_) for id_ in ids]
|
||||||
|
shared_conversation_info = (
|
||||||
|
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||||
|
)
|
||||||
|
return shared_conversation_info
|
||||||
126
enterprise/server/sharing/shared_event_router.py
Normal file
126
enterprise/server/sharing/shared_event_router.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""Shared Event router for OpenHands Server."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from server.sharing.filesystem_shared_event_service import (
|
||||||
|
SharedEventServiceImplInjector,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_event_service import SharedEventService
|
||||||
|
|
||||||
|
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||||
|
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||||
|
from openhands.sdk import Event
|
||||||
|
|
||||||
|
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
|
||||||
|
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
|
||||||
|
|
||||||
|
|
||||||
|
# Read methods
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/search')
|
||||||
|
async def search_shared_events(
|
||||||
|
conversation_id: Annotated[
|
||||||
|
str,
|
||||||
|
Query(title='Conversation ID to search events for'),
|
||||||
|
],
|
||||||
|
kind__eq: Annotated[
|
||||||
|
EventKind | None,
|
||||||
|
Query(title='Optional filter by event kind'),
|
||||||
|
] = None,
|
||||||
|
timestamp__gte: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||||
|
] = None,
|
||||||
|
timestamp__lt: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Optional filter by timestamp less than'),
|
||||||
|
] = None,
|
||||||
|
sort_order: Annotated[
|
||||||
|
EventSortOrder,
|
||||||
|
Query(title='Sort order for results'),
|
||||||
|
] = EventSortOrder.TIMESTAMP,
|
||||||
|
page_id: Annotated[
|
||||||
|
str | None,
|
||||||
|
Query(title='Optional next_page_id from the previously returned page'),
|
||||||
|
] = None,
|
||||||
|
limit: Annotated[
|
||||||
|
int,
|
||||||
|
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||||
|
] = 100,
|
||||||
|
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||||
|
) -> EventPage:
|
||||||
|
"""Search / List events for a shared conversation."""
|
||||||
|
assert limit > 0
|
||||||
|
assert limit <= 100
|
||||||
|
return await shared_event_service.search_shared_events(
|
||||||
|
conversation_id=UUID(conversation_id),
|
||||||
|
kind__eq=kind__eq,
|
||||||
|
timestamp__gte=timestamp__gte,
|
||||||
|
timestamp__lt=timestamp__lt,
|
||||||
|
sort_order=sort_order,
|
||||||
|
page_id=page_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/count')
|
||||||
|
async def count_shared_events(
|
||||||
|
conversation_id: Annotated[
|
||||||
|
str,
|
||||||
|
Query(title='Conversation ID to count events for'),
|
||||||
|
],
|
||||||
|
kind__eq: Annotated[
|
||||||
|
EventKind | None,
|
||||||
|
Query(title='Optional filter by event kind'),
|
||||||
|
] = None,
|
||||||
|
timestamp__gte: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Optional filter by timestamp greater than or equal to'),
|
||||||
|
] = None,
|
||||||
|
timestamp__lt: Annotated[
|
||||||
|
datetime | None,
|
||||||
|
Query(title='Optional filter by timestamp less than'),
|
||||||
|
] = None,
|
||||||
|
sort_order: Annotated[
|
||||||
|
EventSortOrder,
|
||||||
|
Query(title='Sort order for results'),
|
||||||
|
] = EventSortOrder.TIMESTAMP,
|
||||||
|
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||||
|
) -> int:
|
||||||
|
"""Count events for a shared conversation matching the given filters."""
|
||||||
|
return await shared_event_service.count_shared_events(
|
||||||
|
conversation_id=UUID(conversation_id),
|
||||||
|
kind__eq=kind__eq,
|
||||||
|
timestamp__gte=timestamp__gte,
|
||||||
|
timestamp__lt=timestamp__lt,
|
||||||
|
sort_order=sort_order,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('')
|
||||||
|
async def batch_get_shared_events(
|
||||||
|
conversation_id: Annotated[
|
||||||
|
UUID,
|
||||||
|
Query(title='Conversation ID to get events for'),
|
||||||
|
],
|
||||||
|
id: Annotated[list[str], Query()],
|
||||||
|
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||||
|
) -> list[Event | None]:
|
||||||
|
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||||
|
assert len(id) <= 100
|
||||||
|
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/{conversation_id}/{event_id}')
|
||||||
|
async def get_shared_event(
|
||||||
|
conversation_id: UUID,
|
||||||
|
event_id: str,
|
||||||
|
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||||
|
) -> Event | None:
|
||||||
|
"""Get a single event from a shared conversation by conversation_id and event_id."""
|
||||||
|
return await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||||
64
enterprise/server/sharing/shared_event_service.py
Normal file
64
enterprise/server/sharing/shared_event_service.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||||
|
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||||
|
from openhands.app_server.services.injector import Injector
|
||||||
|
from openhands.sdk import Event
|
||||||
|
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedEventService(ABC):
|
||||||
|
"""Event Service for getting events from shared conversations only."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_shared_event(
|
||||||
|
self, conversation_id: UUID, event_id: str
|
||||||
|
) -> Event | None:
|
||||||
|
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def search_shared_events(
|
||||||
|
self,
|
||||||
|
conversation_id: UUID,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||||
|
page_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> EventPage:
|
||||||
|
"""Search events for a specific shared conversation."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def count_shared_events(
|
||||||
|
self,
|
||||||
|
conversation_id: UUID,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||||
|
) -> int:
|
||||||
|
"""Count events for a specific shared conversation."""
|
||||||
|
|
||||||
|
async def batch_get_shared_events(
|
||||||
|
self, conversation_id: UUID, event_ids: list[str]
|
||||||
|
) -> list[Event | None]:
|
||||||
|
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
|
||||||
|
return await asyncio.gather(
|
||||||
|
*[
|
||||||
|
self.get_shared_event(conversation_id, event_id)
|
||||||
|
for event_id in event_ids
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedEventServiceInjector(
|
||||||
|
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
|
||||||
|
):
|
||||||
|
pass
|
||||||
@@ -0,0 +1,282 @@
|
|||||||
|
"""SQL implementation of SharedConversationInfoService.
|
||||||
|
|
||||||
|
This implementation provides read-only access to shared conversations:
|
||||||
|
- Direct database access without user permission checks
|
||||||
|
- Filters only conversations marked as shared (currently public)
|
||||||
|
- Full async/await support using SQL async db_sessions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from server.sharing.shared_conversation_info_service import (
|
||||||
|
SharedConversationInfoService,
|
||||||
|
SharedConversationInfoServiceInjector,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_conversation_models import (
|
||||||
|
SharedConversation,
|
||||||
|
SharedConversationPage,
|
||||||
|
SharedConversationSortOrder,
|
||||||
|
)
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||||
|
StoredConversationMetadata,
|
||||||
|
)
|
||||||
|
from openhands.app_server.services.injector import InjectorState
|
||||||
|
from openhands.integrations.provider import ProviderType
|
||||||
|
from openhands.sdk.llm import MetricsSnapshot
|
||||||
|
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||||
|
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
|
||||||
|
|
||||||
|
db_session: AsyncSession
|
||||||
|
|
||||||
|
async def search_shared_conversation_info(
|
||||||
|
self,
|
||||||
|
title__contains: str | None = None,
|
||||||
|
created_at__gte: datetime | None = None,
|
||||||
|
created_at__lt: datetime | None = None,
|
||||||
|
updated_at__gte: datetime | None = None,
|
||||||
|
updated_at__lt: datetime | None = None,
|
||||||
|
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
|
||||||
|
page_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
include_sub_conversations: bool = False,
|
||||||
|
) -> SharedConversationPage:
|
||||||
|
"""Search for shared conversations."""
|
||||||
|
query = self._public_select()
|
||||||
|
|
||||||
|
# Conditionally exclude sub-conversations based on the parameter
|
||||||
|
if not include_sub_conversations:
|
||||||
|
# Exclude sub-conversations (only include top-level conversations)
|
||||||
|
query = query.where(
|
||||||
|
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||||
|
)
|
||||||
|
|
||||||
|
query = self._apply_filters(
|
||||||
|
query=query,
|
||||||
|
title__contains=title__contains,
|
||||||
|
created_at__gte=created_at__gte,
|
||||||
|
created_at__lt=created_at__lt,
|
||||||
|
updated_at__gte=updated_at__gte,
|
||||||
|
updated_at__lt=updated_at__lt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add sort order
|
||||||
|
if sort_order == SharedConversationSortOrder.CREATED_AT:
|
||||||
|
query = query.order_by(StoredConversationMetadata.created_at)
|
||||||
|
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
|
||||||
|
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||||
|
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
|
||||||
|
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||||
|
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
|
||||||
|
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||||
|
elif sort_order == SharedConversationSortOrder.TITLE:
|
||||||
|
query = query.order_by(StoredConversationMetadata.title)
|
||||||
|
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
|
||||||
|
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
if page_id is not None:
|
||||||
|
try:
|
||||||
|
offset = int(page_id)
|
||||||
|
query = query.offset(offset)
|
||||||
|
except ValueError:
|
||||||
|
# If page_id is not a valid integer, start from beginning
|
||||||
|
offset = 0
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
# Apply limit and get one extra to check if there are more results
|
||||||
|
query = query.limit(limit + 1)
|
||||||
|
|
||||||
|
result = await self.db_session.execute(query)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
# Check if there are more results
|
||||||
|
has_more = len(rows) > limit
|
||||||
|
if has_more:
|
||||||
|
rows = rows[:limit]
|
||||||
|
|
||||||
|
items = [self._to_shared_conversation(row) for row in rows]
|
||||||
|
|
||||||
|
# Calculate next page ID
|
||||||
|
next_page_id = None
|
||||||
|
if has_more:
|
||||||
|
next_page_id = str(offset + limit)
|
||||||
|
|
||||||
|
return SharedConversationPage(items=items, next_page_id=next_page_id)
|
||||||
|
|
||||||
|
async def count_shared_conversation_info(
|
||||||
|
self,
|
||||||
|
title__contains: str | None = None,
|
||||||
|
created_at__gte: datetime | None = None,
|
||||||
|
created_at__lt: datetime | None = None,
|
||||||
|
updated_at__gte: datetime | None = None,
|
||||||
|
updated_at__lt: datetime | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count shared conversations matching the given filters."""
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
query = select(func.count(StoredConversationMetadata.conversation_id))
|
||||||
|
# Only include shared conversations
|
||||||
|
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||||
|
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||||
|
|
||||||
|
query = self._apply_filters(
|
||||||
|
query=query,
|
||||||
|
title__contains=title__contains,
|
||||||
|
created_at__gte=created_at__gte,
|
||||||
|
created_at__lt=created_at__lt,
|
||||||
|
updated_at__gte=updated_at__gte,
|
||||||
|
updated_at__lt=updated_at__lt,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.db_session.execute(query)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_shared_conversation_info(
|
||||||
|
self, conversation_id: UUID
|
||||||
|
) -> SharedConversation | None:
|
||||||
|
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||||
|
query = self._public_select().where(
|
||||||
|
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.db_session.execute(query)
|
||||||
|
stored = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if stored is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._to_shared_conversation(stored)
|
||||||
|
|
||||||
|
def _public_select(self):
|
||||||
|
"""Create a select query that only returns public conversations."""
|
||||||
|
query = select(StoredConversationMetadata).where(
|
||||||
|
StoredConversationMetadata.conversation_version == 'V1'
|
||||||
|
)
|
||||||
|
# Only include conversations marked as public
|
||||||
|
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||||
|
return query
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
title__contains: str | None = None,
|
||||||
|
created_at__gte: datetime | None = None,
|
||||||
|
created_at__lt: datetime | None = None,
|
||||||
|
updated_at__gte: datetime | None = None,
|
||||||
|
updated_at__lt: datetime | None = None,
|
||||||
|
):
|
||||||
|
"""Apply common filters to a query."""
|
||||||
|
if title__contains is not None:
|
||||||
|
query = query.where(
|
||||||
|
StoredConversationMetadata.title.contains(title__contains)
|
||||||
|
)
|
||||||
|
|
||||||
|
if created_at__gte is not None:
|
||||||
|
query = query.where(
|
||||||
|
StoredConversationMetadata.created_at >= created_at__gte
|
||||||
|
)
|
||||||
|
|
||||||
|
if created_at__lt is not None:
|
||||||
|
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
|
||||||
|
|
||||||
|
if updated_at__gte is not None:
|
||||||
|
query = query.where(
|
||||||
|
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||||
|
)
|
||||||
|
|
||||||
|
if updated_at__lt is not None:
|
||||||
|
query = query.where(
|
||||||
|
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||||
|
)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
def _to_shared_conversation(
|
||||||
|
self,
|
||||||
|
stored: StoredConversationMetadata,
|
||||||
|
sub_conversation_ids: list[UUID] | None = None,
|
||||||
|
) -> SharedConversation:
|
||||||
|
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||||
|
# V1 conversations should always have a sandbox_id
|
||||||
|
sandbox_id = stored.sandbox_id
|
||||||
|
assert sandbox_id is not None
|
||||||
|
|
||||||
|
# Rebuild token usage
|
||||||
|
token_usage = TokenUsage(
|
||||||
|
prompt_tokens=stored.prompt_tokens,
|
||||||
|
completion_tokens=stored.completion_tokens,
|
||||||
|
cache_read_tokens=stored.cache_read_tokens,
|
||||||
|
cache_write_tokens=stored.cache_write_tokens,
|
||||||
|
context_window=stored.context_window,
|
||||||
|
per_turn_token=stored.per_turn_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rebuild metrics object
|
||||||
|
metrics = MetricsSnapshot(
|
||||||
|
accumulated_cost=stored.accumulated_cost,
|
||||||
|
max_budget_per_task=stored.max_budget_per_task,
|
||||||
|
accumulated_token_usage=token_usage,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get timestamps
|
||||||
|
created_at = self._fix_timezone(stored.created_at)
|
||||||
|
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||||
|
|
||||||
|
return SharedConversation(
|
||||||
|
id=UUID(stored.conversation_id),
|
||||||
|
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||||
|
sandbox_id=stored.sandbox_id,
|
||||||
|
selected_repository=stored.selected_repository,
|
||||||
|
selected_branch=stored.selected_branch,
|
||||||
|
git_provider=(
|
||||||
|
ProviderType(stored.git_provider) if stored.git_provider else None
|
||||||
|
),
|
||||||
|
title=stored.title,
|
||||||
|
pr_number=stored.pr_number,
|
||||||
|
llm_model=stored.llm_model,
|
||||||
|
metrics=metrics,
|
||||||
|
parent_conversation_id=(
|
||||||
|
UUID(stored.parent_conversation_id)
|
||||||
|
if stored.parent_conversation_id
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
sub_conversation_ids=sub_conversation_ids or [],
|
||||||
|
created_at=created_at,
|
||||||
|
updated_at=updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fix_timezone(self, value: datetime) -> datetime:
|
||||||
|
"""Sqlite does not store timezones - and since we can't update the existing models
|
||||||
|
we assume UTC if the timezone is missing."""
|
||||||
|
if not value.tzinfo:
|
||||||
|
value = value.replace(tzinfo=UTC)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
|
||||||
|
async def inject(
|
||||||
|
self, state: InjectorState, request: Request | None = None
|
||||||
|
) -> AsyncGenerator[SharedConversationInfoService, None]:
|
||||||
|
# Define inline to prevent circular lookup
|
||||||
|
from openhands.app_server.config import get_db_session
|
||||||
|
|
||||||
|
async with get_db_session(state, request) as db_session:
|
||||||
|
service = SQLSharedConversationInfoService(db_session=db_session)
|
||||||
|
yield service
|
||||||
@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
|
|||||||
kwargs.pop('context_window', None)
|
kwargs.pop('context_window', None)
|
||||||
kwargs.pop('per_turn_token', None)
|
kwargs.pop('per_turn_token', None)
|
||||||
kwargs.pop('parent_conversation_id', None)
|
kwargs.pop('parent_conversation_id', None)
|
||||||
|
kwargs.pop('public')
|
||||||
|
|
||||||
return ConversationMetadata(**kwargs)
|
return ConversationMetadata(**kwargs)
|
||||||
|
|
||||||
|
|||||||
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
1
enterprise/tests/unit/test_sharing/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for sharing package."""
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
"""Tests for public conversation models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from server.sharing.shared_conversation_models import (
|
||||||
|
SharedConversation,
|
||||||
|
SharedConversationPage,
|
||||||
|
SharedConversationSortOrder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_conversation_creation():
|
||||||
|
"""Test that SharedConversation can be created with all required fields."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
conversation = SharedConversation(
|
||||||
|
id=conversation_id,
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox',
|
||||||
|
title='Test Conversation',
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
selected_repository=None,
|
||||||
|
parent_conversation_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation.id == conversation_id
|
||||||
|
assert conversation.title == 'Test Conversation'
|
||||||
|
assert conversation.created_by_user_id == 'test_user'
|
||||||
|
assert conversation.sandbox_id == 'test_sandbox'
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_conversation_page_creation():
|
||||||
|
"""Test that SharedConversationPage can be created."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
conversation = SharedConversation(
|
||||||
|
id=conversation_id,
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox',
|
||||||
|
title='Test Conversation',
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
selected_repository=None,
|
||||||
|
parent_conversation_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
page = SharedConversationPage(
|
||||||
|
items=[conversation],
|
||||||
|
next_page_id='next_page',
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(page.items) == 1
|
||||||
|
assert page.items[0].id == conversation_id
|
||||||
|
assert page.next_page_id == 'next_page'
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_conversation_sort_order_enum():
|
||||||
|
"""Test that SharedConversationSortOrder enum has expected values."""
|
||||||
|
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
|
||||||
|
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
|
||||||
|
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
|
||||||
|
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
|
||||||
|
assert hasattr(SharedConversationSortOrder, 'TITLE')
|
||||||
|
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_conversation_optional_fields():
|
||||||
|
"""Test that SharedConversation works with optional fields."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
parent_id = uuid4()
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
conversation = SharedConversation(
|
||||||
|
id=conversation_id,
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox',
|
||||||
|
title='Test Conversation',
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
selected_repository='owner/repo',
|
||||||
|
parent_conversation_id=parent_id,
|
||||||
|
llm_model='gpt-4',
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation.selected_repository == 'owner/repo'
|
||||||
|
assert conversation.parent_conversation_id == parent_id
|
||||||
|
assert conversation.llm_model == 'gpt-4'
|
||||||
@@ -0,0 +1,430 @@
|
|||||||
|
"""Tests for SharedConversationInfoService."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from server.sharing.shared_conversation_models import (
|
||||||
|
SharedConversationSortOrder,
|
||||||
|
)
|
||||||
|
from server.sharing.sql_shared_conversation_info_service import (
|
||||||
|
SQLSharedConversationInfoService,
|
||||||
|
)
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||||
|
AppConversationInfo,
|
||||||
|
)
|
||||||
|
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||||
|
SQLAppConversationInfoService,
|
||||||
|
)
|
||||||
|
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||||
|
from openhands.app_server.utils.sql_utils import Base
|
||||||
|
from openhands.integrations.provider import ProviderType
|
||||||
|
from openhands.sdk.llm import MetricsSnapshot
|
||||||
|
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||||
|
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_engine():
|
||||||
|
"""Create an async SQLite engine for testing."""
|
||||||
|
engine = create_async_engine(
|
||||||
|
'sqlite+aiosqlite:///:memory:',
|
||||||
|
poolclass=StaticPool,
|
||||||
|
connect_args={'check_same_thread': False},
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create all tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
yield engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create an async session for testing."""
|
||||||
|
async_session_maker = async_sessionmaker(
|
||||||
|
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
async with async_session_maker() as db_session:
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def shared_conversation_info_service(async_session):
|
||||||
|
"""Create a SharedConversationInfoService for testing."""
|
||||||
|
return SQLSharedConversationInfoService(db_session=async_session)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def app_conversation_service(async_session):
|
||||||
|
"""Create an AppConversationInfoService for creating test data."""
|
||||||
|
return SQLAppConversationInfoService(
|
||||||
|
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_conversation_info():
|
||||||
|
"""Create a sample conversation info for testing."""
|
||||||
|
return AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox',
|
||||||
|
selected_repository='test/repo',
|
||||||
|
selected_branch='main',
|
||||||
|
git_provider=ProviderType.GITHUB,
|
||||||
|
title='Test Conversation',
|
||||||
|
trigger=ConversationTrigger.GUI,
|
||||||
|
pr_number=[123],
|
||||||
|
llm_model='gpt-4',
|
||||||
|
metrics=MetricsSnapshot(
|
||||||
|
accumulated_cost=1.5,
|
||||||
|
max_budget_per_task=10.0,
|
||||||
|
accumulated_token_usage=TokenUsage(
|
||||||
|
prompt_tokens=100,
|
||||||
|
completion_tokens=50,
|
||||||
|
cache_read_tokens=0,
|
||||||
|
cache_write_tokens=0,
|
||||||
|
context_window=4096,
|
||||||
|
per_turn_token=150,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
parent_conversation_id=None,
|
||||||
|
sub_conversation_ids=[],
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
public=True, # Make it public for testing
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_private_conversation_info():
|
||||||
|
"""Create a sample private conversation info for testing."""
|
||||||
|
return AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox_private',
|
||||||
|
selected_repository='test/private_repo',
|
||||||
|
selected_branch='main',
|
||||||
|
git_provider=ProviderType.GITHUB,
|
||||||
|
title='Private Conversation',
|
||||||
|
trigger=ConversationTrigger.GUI,
|
||||||
|
pr_number=[124],
|
||||||
|
llm_model='gpt-4',
|
||||||
|
metrics=MetricsSnapshot(
|
||||||
|
accumulated_cost=2.0,
|
||||||
|
max_budget_per_task=10.0,
|
||||||
|
accumulated_token_usage=TokenUsage(
|
||||||
|
prompt_tokens=200,
|
||||||
|
completion_tokens=100,
|
||||||
|
cache_read_tokens=0,
|
||||||
|
cache_write_tokens=0,
|
||||||
|
context_window=4096,
|
||||||
|
per_turn_token=300,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
parent_conversation_id=None,
|
||||||
|
sub_conversation_ids=[],
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
public=False, # Make it private
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharedConversationInfoService:
|
||||||
|
"""Test cases for SharedConversationInfoService."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_shared_conversation_info_returns_public_conversation(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
sample_conversation_info,
|
||||||
|
):
|
||||||
|
"""Test that get_shared_conversation_info returns a public conversation."""
|
||||||
|
# Create a public conversation
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_conversation_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retrieve it via public service
|
||||||
|
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||||
|
sample_conversation_info.id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == sample_conversation_info.id
|
||||||
|
assert result.title == sample_conversation_info.title
|
||||||
|
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
sample_private_conversation_info,
|
||||||
|
):
|
||||||
|
"""Test that get_shared_conversation_info returns None for private conversations."""
|
||||||
|
# Create a private conversation
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_private_conversation_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to retrieve it via public service
|
||||||
|
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||||
|
sample_private_conversation_info.id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
|
||||||
|
self, shared_conversation_info_service
|
||||||
|
):
|
||||||
|
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
|
||||||
|
nonexistent_id = uuid4()
|
||||||
|
result = await shared_conversation_info_service.get_shared_conversation_info(
|
||||||
|
nonexistent_id
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_shared_conversation_info_returns_only_public_conversations(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
sample_conversation_info,
|
||||||
|
sample_private_conversation_info,
|
||||||
|
):
|
||||||
|
"""Test that search only returns public conversations."""
|
||||||
|
# Create both public and private conversations
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_conversation_info
|
||||||
|
)
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_private_conversation_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search for all conversations
|
||||||
|
result = (
|
||||||
|
await shared_conversation_info_service.search_shared_conversation_info()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return the public conversation
|
||||||
|
assert len(result.items) == 1
|
||||||
|
assert result.items[0].id == sample_conversation_info.id
|
||||||
|
assert result.items[0].title == sample_conversation_info.title
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_shared_conversation_info_with_title_filter(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
sample_conversation_info,
|
||||||
|
):
|
||||||
|
"""Test searching with title filter."""
|
||||||
|
# Create a public conversation
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_conversation_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search with matching title
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
title__contains='Test'
|
||||||
|
)
|
||||||
|
assert len(result.items) == 1
|
||||||
|
|
||||||
|
# Search with non-matching title
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
title__contains='NonExistent'
|
||||||
|
)
|
||||||
|
assert len(result.items) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_shared_conversation_info_with_sort_order(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
):
|
||||||
|
"""Test searching with different sort orders."""
|
||||||
|
# Create multiple public conversations with different titles and timestamps
|
||||||
|
conv1 = AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox_1',
|
||||||
|
title='A First Conversation',
|
||||||
|
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||||
|
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||||
|
public=True,
|
||||||
|
metrics=MetricsSnapshot(
|
||||||
|
accumulated_cost=0.0,
|
||||||
|
max_budget_per_task=10.0,
|
||||||
|
accumulated_token_usage=TokenUsage(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conv2 = AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox_2',
|
||||||
|
title='B Second Conversation',
|
||||||
|
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||||
|
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||||
|
public=True,
|
||||||
|
metrics=MetricsSnapshot(
|
||||||
|
accumulated_cost=0.0,
|
||||||
|
max_budget_per_task=10.0,
|
||||||
|
accumulated_token_usage=TokenUsage(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await app_conversation_service.save_app_conversation_info(conv1)
|
||||||
|
await app_conversation_service.save_app_conversation_info(conv2)
|
||||||
|
|
||||||
|
# Test sort by title ascending
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
sort_order=SharedConversationSortOrder.TITLE
|
||||||
|
)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.items[0].title == 'A First Conversation'
|
||||||
|
assert result.items[1].title == 'B Second Conversation'
|
||||||
|
|
||||||
|
# Test sort by title descending
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
sort_order=SharedConversationSortOrder.TITLE_DESC
|
||||||
|
)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.items[0].title == 'B Second Conversation'
|
||||||
|
assert result.items[1].title == 'A First Conversation'
|
||||||
|
|
||||||
|
# Test sort by created_at ascending
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||||
|
)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.items[0].id == conv1.id
|
||||||
|
assert result.items[1].id == conv2.id
|
||||||
|
|
||||||
|
# Test sort by created_at descending (default)
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
|
||||||
|
)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.items[0].id == conv2.id
|
||||||
|
assert result.items[1].id == conv1.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_shared_conversation_info(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
sample_conversation_info,
|
||||||
|
sample_private_conversation_info,
|
||||||
|
):
|
||||||
|
"""Test counting public conversations."""
|
||||||
|
# Initially should be 0
|
||||||
|
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
# Create a public conversation
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_conversation_info
|
||||||
|
)
|
||||||
|
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
# Create a private conversation - count should remain 1
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_private_conversation_info
|
||||||
|
)
|
||||||
|
count = await shared_conversation_info_service.count_shared_conversation_info()
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_get_shared_conversation_info(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
sample_conversation_info,
|
||||||
|
sample_private_conversation_info,
|
||||||
|
):
|
||||||
|
"""Test batch getting public conversations."""
|
||||||
|
# Create both public and private conversations
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_conversation_info
|
||||||
|
)
|
||||||
|
await app_conversation_service.save_app_conversation_info(
|
||||||
|
sample_private_conversation_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batch get both conversations
|
||||||
|
result = (
|
||||||
|
await shared_conversation_info_service.batch_get_shared_conversation_info(
|
||||||
|
[sample_conversation_info.id, sample_private_conversation_info.id]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return the public one and None for the private one
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0] is not None
|
||||||
|
assert result[0].id == sample_conversation_info.id
|
||||||
|
assert result[1] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_pagination(
|
||||||
|
self,
|
||||||
|
shared_conversation_info_service,
|
||||||
|
app_conversation_service,
|
||||||
|
):
|
||||||
|
"""Test search with pagination."""
|
||||||
|
# Create multiple public conversations
|
||||||
|
conversations = []
|
||||||
|
for i in range(5):
|
||||||
|
conv = AppConversationInfo(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id=f'test_sandbox_{i}',
|
||||||
|
title=f'Conversation {i}',
|
||||||
|
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||||
|
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
|
||||||
|
public=True,
|
||||||
|
metrics=MetricsSnapshot(
|
||||||
|
accumulated_cost=0.0,
|
||||||
|
max_budget_per_task=10.0,
|
||||||
|
accumulated_token_usage=TokenUsage(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conversations.append(conv)
|
||||||
|
await app_conversation_service.save_app_conversation_info(conv)
|
||||||
|
|
||||||
|
# Get first page with limit 2
|
||||||
|
result = await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
|
||||||
|
)
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.next_page_id is not None
|
||||||
|
|
||||||
|
# Get next page
|
||||||
|
result2 = (
|
||||||
|
await shared_conversation_info_service.search_shared_conversation_info(
|
||||||
|
limit=2,
|
||||||
|
page_id=result.next_page_id,
|
||||||
|
sort_order=SharedConversationSortOrder.CREATED_AT,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(result2.items) == 2
|
||||||
|
assert result2.next_page_id is not None
|
||||||
|
|
||||||
|
# Verify no overlap between pages
|
||||||
|
page1_ids = {item.id for item in result.items}
|
||||||
|
page2_ids = {item.id for item in result2.items}
|
||||||
|
assert page1_ids.isdisjoint(page2_ids)
|
||||||
@@ -0,0 +1,365 @@
|
|||||||
|
"""Tests for SharedEventService."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from server.sharing.filesystem_shared_event_service import (
|
||||||
|
SharedEventServiceImpl,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_conversation_info_service import (
|
||||||
|
SharedConversationInfoService,
|
||||||
|
)
|
||||||
|
from server.sharing.shared_conversation_models import SharedConversation
|
||||||
|
|
||||||
|
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||||
|
from openhands.app_server.event.event_service import EventService
|
||||||
|
from openhands.sdk.llm import MetricsSnapshot
|
||||||
|
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_shared_conversation_info_service():
|
||||||
|
"""Create a mock SharedConversationInfoService."""
|
||||||
|
return AsyncMock(spec=SharedConversationInfoService)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_event_service():
|
||||||
|
"""Create a mock EventService."""
|
||||||
|
return AsyncMock(spec=EventService)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
|
||||||
|
"""Create a SharedEventService for testing."""
|
||||||
|
return SharedEventServiceImpl(
|
||||||
|
shared_conversation_info_service=mock_shared_conversation_info_service,
|
||||||
|
event_service=mock_event_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_public_conversation():
|
||||||
|
"""Create a sample public conversation."""
|
||||||
|
return SharedConversation(
|
||||||
|
id=uuid4(),
|
||||||
|
created_by_user_id='test_user',
|
||||||
|
sandbox_id='test_sandbox',
|
||||||
|
title='Test Public Conversation',
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
metrics=MetricsSnapshot(
|
||||||
|
accumulated_cost=0.0,
|
||||||
|
max_budget_per_task=10.0,
|
||||||
|
accumulated_token_usage=TokenUsage(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_event():
|
||||||
|
"""Create a sample event."""
|
||||||
|
# For testing purposes, we'll just use a mock that the EventPage can accept
|
||||||
|
# The actual event creation is complex and not the focus of these tests
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharedEventService:
|
||||||
|
"""Test cases for SharedEventService."""
|
||||||
|
|
||||||
|
async def test_get_shared_event_returns_event_for_public_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
sample_public_conversation,
|
||||||
|
sample_event,
|
||||||
|
):
|
||||||
|
"""Test that get_shared_event returns an event for a public conversation."""
|
||||||
|
conversation_id = sample_public_conversation.id
|
||||||
|
event_id = 'test_event_id'
|
||||||
|
|
||||||
|
# Mock the public conversation service to return a public conversation
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||||
|
|
||||||
|
# Mock the event service to return an event
|
||||||
|
mock_event_service.get_event.return_value = sample_event
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == sample_event
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
mock_event_service.get_event.assert_called_once_with(event_id)
|
||||||
|
|
||||||
|
async def test_get_shared_event_returns_none_for_private_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
):
|
||||||
|
"""Test that get_shared_event returns None for a private conversation."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
event_id = 'test_event_id'
|
||||||
|
|
||||||
|
# Mock the public conversation service to return None (private conversation)
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.get_shared_event(conversation_id, event_id)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result is None
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
# Event service should not be called
|
||||||
|
mock_event_service.get_event.assert_not_called()
|
||||||
|
|
||||||
|
async def test_search_shared_events_returns_events_for_public_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
sample_public_conversation,
|
||||||
|
sample_event,
|
||||||
|
):
|
||||||
|
"""Test that search_shared_events returns events for a public conversation."""
|
||||||
|
conversation_id = sample_public_conversation.id
|
||||||
|
|
||||||
|
# Mock the public conversation service to return a public conversation
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||||
|
|
||||||
|
# Mock the event service to return events
|
||||||
|
mock_event_page = EventPage(items=[], next_page_id=None)
|
||||||
|
mock_event_service.search_events.return_value = mock_event_page
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.search_shared_events(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
kind__eq='ActionEvent',
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == mock_event_page
|
||||||
|
assert len(result.items) == 0 # Empty list as we mocked
|
||||||
|
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
mock_event_service.search_events.assert_called_once_with(
|
||||||
|
conversation_id__eq=conversation_id,
|
||||||
|
kind__eq='ActionEvent',
|
||||||
|
timestamp__gte=None,
|
||||||
|
timestamp__lt=None,
|
||||||
|
sort_order=EventSortOrder.TIMESTAMP,
|
||||||
|
page_id=None,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_search_shared_events_returns_empty_for_private_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
):
|
||||||
|
"""Test that search_shared_events returns empty page for a private conversation."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
|
||||||
|
# Mock the public conversation service to return None (private conversation)
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.search_shared_events(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert isinstance(result, EventPage)
|
||||||
|
assert len(result.items) == 0
|
||||||
|
assert result.next_page_id is None
|
||||||
|
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
# Event service should not be called
|
||||||
|
mock_event_service.search_events.assert_not_called()
|
||||||
|
|
||||||
|
async def test_count_shared_events_returns_count_for_public_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
sample_public_conversation,
|
||||||
|
):
|
||||||
|
"""Test that count_shared_events returns count for a public conversation."""
|
||||||
|
conversation_id = sample_public_conversation.id
|
||||||
|
|
||||||
|
# Mock the public conversation service to return a public conversation
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||||
|
|
||||||
|
# Mock the event service to return a count
|
||||||
|
mock_event_service.count_events.return_value = 5
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.count_shared_events(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
kind__eq='ActionEvent',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == 5
|
||||||
|
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
mock_event_service.count_events.assert_called_once_with(
|
||||||
|
conversation_id__eq=conversation_id,
|
||||||
|
kind__eq='ActionEvent',
|
||||||
|
timestamp__gte=None,
|
||||||
|
timestamp__lt=None,
|
||||||
|
sort_order=EventSortOrder.TIMESTAMP,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_count_shared_events_returns_zero_for_private_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
):
|
||||||
|
"""Test that count_shared_events returns 0 for a private conversation."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
|
||||||
|
# Mock the public conversation service to return None (private conversation)
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.count_shared_events(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
# Event service should not be called
|
||||||
|
mock_event_service.count_events.assert_not_called()
|
||||||
|
|
||||||
|
async def test_batch_get_shared_events_returns_events_for_public_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
sample_public_conversation,
|
||||||
|
sample_event,
|
||||||
|
):
|
||||||
|
"""Test that batch_get_shared_events returns events for a public conversation."""
|
||||||
|
conversation_id = sample_public_conversation.id
|
||||||
|
event_ids = ['event1', 'event2']
|
||||||
|
|
||||||
|
# Mock the public conversation service to return a public conversation
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||||
|
|
||||||
|
# Mock the event service to return events
|
||||||
|
mock_event_service.get_event.side_effect = [sample_event, None]
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.batch_get_shared_events(
|
||||||
|
conversation_id, event_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0] == sample_event
|
||||||
|
assert result[1] is None
|
||||||
|
|
||||||
|
# Verify that get_shared_conversation_info was called for each event
|
||||||
|
assert (
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||||
|
== 2
|
||||||
|
)
|
||||||
|
# Verify that get_event was called for each event
|
||||||
|
assert mock_event_service.get_event.call_count == 2
|
||||||
|
|
||||||
|
async def test_batch_get_shared_events_returns_none_for_private_conversation(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
):
|
||||||
|
"""Test that batch_get_shared_events returns None for a private conversation."""
|
||||||
|
conversation_id = uuid4()
|
||||||
|
event_ids = ['event1', 'event2']
|
||||||
|
|
||||||
|
# Mock the public conversation service to return None (private conversation)
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = await shared_event_service.batch_get_shared_events(
|
||||||
|
conversation_id, event_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0] is None
|
||||||
|
assert result[1] is None
|
||||||
|
|
||||||
|
# Verify that get_shared_conversation_info was called for each event
|
||||||
|
assert (
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
|
||||||
|
== 2
|
||||||
|
)
|
||||||
|
# Event service should not be called
|
||||||
|
mock_event_service.get_event.assert_not_called()
|
||||||
|
|
||||||
|
async def test_search_shared_events_with_all_parameters(
|
||||||
|
self,
|
||||||
|
shared_event_service,
|
||||||
|
mock_shared_conversation_info_service,
|
||||||
|
mock_event_service,
|
||||||
|
sample_public_conversation,
|
||||||
|
):
|
||||||
|
"""Test search_shared_events with all parameters."""
|
||||||
|
conversation_id = sample_public_conversation.id
|
||||||
|
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
|
||||||
|
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
|
||||||
|
|
||||||
|
# Mock the public conversation service to return a public conversation
|
||||||
|
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||||
|
|
||||||
|
# Mock the event service to return events
|
||||||
|
mock_event_page = EventPage(items=[], next_page_id='next_page')
|
||||||
|
mock_event_service.search_events.return_value = mock_event_page
|
||||||
|
|
||||||
|
# Call the method with all parameters
|
||||||
|
result = await shared_event_service.search_shared_events(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
kind__eq='ObservationEvent',
|
||||||
|
timestamp__gte=timestamp_gte,
|
||||||
|
timestamp__lt=timestamp_lt,
|
||||||
|
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||||
|
page_id='current_page',
|
||||||
|
limit=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == mock_event_page
|
||||||
|
|
||||||
|
mock_event_service.search_events.assert_called_once_with(
|
||||||
|
conversation_id__eq=conversation_id,
|
||||||
|
kind__eq='ObservationEvent',
|
||||||
|
timestamp__gte=timestamp_gte,
|
||||||
|
timestamp__lt=timestamp_lt,
|
||||||
|
sort_order=EventSortOrder.TIMESTAMP_DESC,
|
||||||
|
page_id='current_page',
|
||||||
|
limit=50,
|
||||||
|
)
|
||||||
@@ -45,6 +45,8 @@ class AppConversationInfo(BaseModel):
|
|||||||
parent_conversation_id: OpenHandsUUID | None = None
|
parent_conversation_id: OpenHandsUUID | None = None
|
||||||
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
|
||||||
|
|
||||||
|
public: bool | None = None
|
||||||
|
|
||||||
created_at: datetime = Field(default_factory=utc_now)
|
created_at: datetime = Field(default_factory=utc_now)
|
||||||
updated_at: datetime = Field(default_factory=utc_now)
|
updated_at: datetime = Field(default_factory=utc_now)
|
||||||
|
|
||||||
@@ -114,6 +116,12 @@ class AppConversationStartRequest(BaseModel):
|
|||||||
parent_conversation_id: OpenHandsUUID | None = None
|
parent_conversation_id: OpenHandsUUID | None = None
|
||||||
agent_type: AgentType = Field(default=AgentType.DEFAULT)
|
agent_type: AgentType = Field(default=AgentType.DEFAULT)
|
||||||
|
|
||||||
|
public: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AppConversationUpdateRequest(BaseModel):
|
||||||
|
public: bool
|
||||||
|
|
||||||
|
|
||||||
class AppConversationStartTaskStatus(Enum):
|
class AppConversationStartTaskStatus(Enum):
|
||||||
WORKING = 'WORKING'
|
WORKING = 'WORKING'
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
|||||||
AppConversationStartTask,
|
AppConversationStartTask,
|
||||||
AppConversationStartTaskPage,
|
AppConversationStartTaskPage,
|
||||||
AppConversationStartTaskSortOrder,
|
AppConversationStartTaskSortOrder,
|
||||||
|
AppConversationUpdateRequest,
|
||||||
SkillResponse,
|
SkillResponse,
|
||||||
)
|
)
|
||||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||||
@@ -222,6 +223,22 @@ async def start_app_conversation(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch('/{conversation_id}')
|
||||||
|
async def update_app_conversation(
|
||||||
|
conversation_id: str,
|
||||||
|
update_request: AppConversationUpdateRequest,
|
||||||
|
app_conversation_service: AppConversationService = (
|
||||||
|
app_conversation_service_dependency
|
||||||
|
),
|
||||||
|
) -> AppConversation:
|
||||||
|
info = await app_conversation_service.update_app_conversation(
|
||||||
|
UUID(conversation_id), update_request
|
||||||
|
)
|
||||||
|
if info is None:
|
||||||
|
raise HTTPException(404, 'unknown_app_conversation')
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
@router.post('/stream-start')
|
@router.post('/stream-start')
|
||||||
async def stream_app_conversation_start(
|
async def stream_app_conversation_start(
|
||||||
request: AppConversationStartRequest,
|
request: AppConversationStartRequest,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
|||||||
AppConversationSortOrder,
|
AppConversationSortOrder,
|
||||||
AppConversationStartRequest,
|
AppConversationStartRequest,
|
||||||
AppConversationStartTask,
|
AppConversationStartTask,
|
||||||
|
AppConversationUpdateRequest,
|
||||||
)
|
)
|
||||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||||
from openhands.app_server.services.injector import Injector
|
from openhands.app_server.services.injector import Injector
|
||||||
@@ -98,6 +99,13 @@ class AppConversationService(ABC):
|
|||||||
"""Run the setup scripts for the project and yield status updates"""
|
"""Run the setup scripts for the project and yield status updates"""
|
||||||
yield task
|
yield task
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_app_conversation(
|
||||||
|
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||||
|
) -> AppConversation | None:
|
||||||
|
"""Update an app conversation and return it. Return None if the conversation
|
||||||
|
did not exist."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||||
"""Delete a V1 conversation and all its associated data.
|
"""Delete a V1 conversation and all its associated data.
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
|||||||
AppConversationStartRequest,
|
AppConversationStartRequest,
|
||||||
AppConversationStartTask,
|
AppConversationStartTask,
|
||||||
AppConversationStartTaskStatus,
|
AppConversationStartTaskStatus,
|
||||||
|
AppConversationUpdateRequest,
|
||||||
)
|
)
|
||||||
from openhands.app_server.app_conversation.app_conversation_service import (
|
from openhands.app_server.app_conversation.app_conversation_service import (
|
||||||
AppConversationService,
|
AppConversationService,
|
||||||
@@ -1049,6 +1050,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
|||||||
f'Successfully updated agent-server conversation {conversation_id} title to "{new_title}"'
|
f'Successfully updated agent-server conversation {conversation_id} title to "{new_title}"'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def update_app_conversation(
|
||||||
|
self, conversation_id: UUID, request: AppConversationUpdateRequest
|
||||||
|
) -> AppConversation | None:
|
||||||
|
"""Update an app conversation and return it. Return None if the conversation
|
||||||
|
did not exist."""
|
||||||
|
info = await self.app_conversation_info_service.get_app_conversation_info(
|
||||||
|
conversation_id
|
||||||
|
)
|
||||||
|
if info is None:
|
||||||
|
return None
|
||||||
|
for field_name in request.model_fields:
|
||||||
|
value = getattr(request, field_name)
|
||||||
|
setattr(info, field_name, value)
|
||||||
|
info = await self.app_conversation_info_service.save_app_conversation_info(info)
|
||||||
|
conversations = await self._build_app_conversations([info])
|
||||||
|
return conversations[0]
|
||||||
|
|
||||||
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
|
||||||
"""Delete a V1 conversation and all its associated data.
|
"""Delete a V1 conversation and all its associated data.
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,17 @@ from typing import AsyncGenerator
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy import Column, DateTime, Float, Integer, Select, String, func, select
|
from sqlalchemy import (
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Float,
|
||||||
|
Integer,
|
||||||
|
Select,
|
||||||
|
String,
|
||||||
|
func,
|
||||||
|
select,
|
||||||
|
)
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from openhands.agent_server.utils import utc_now
|
from openhands.agent_server.utils import utc_now
|
||||||
@@ -91,6 +101,7 @@ class StoredConversationMetadata(Base): # type: ignore
|
|||||||
conversation_version = Column(String, nullable=False, default='V0', index=True)
|
conversation_version = Column(String, nullable=False, default='V0', index=True)
|
||||||
sandbox_id = Column(String, nullable=True, index=True)
|
sandbox_id = Column(String, nullable=True, index=True)
|
||||||
parent_conversation_id = Column(String, nullable=True, index=True)
|
parent_conversation_id = Column(String, nullable=True, index=True)
|
||||||
|
public = Column(Boolean, nullable=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -350,6 +361,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
|||||||
if info.parent_conversation_id
|
if info.parent_conversation_id
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
public=info.public,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.db_session.merge(stored)
|
await self.db_session.merge(stored)
|
||||||
@@ -541,6 +553,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
|
|||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
sub_conversation_ids=sub_conversation_ids or [],
|
sub_conversation_ids=sub_conversation_ids or [],
|
||||||
|
public=stored.public,
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
)
|
)
|
||||||
|
|||||||
41
openhands/app_server/app_lifespan/alembic/versions/004.py
Normal file
41
openhands/app_server/app_lifespan/alembic/versions/004.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""add public column to conversation_metadata
|
||||||
|
|
||||||
|
Revision ID: 004
|
||||||
|
Revises: 003
|
||||||
|
Create Date: 2025-01-27 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '004'
|
||||||
|
down_revision: Union[str, None] = '003'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
op.add_column(
|
||||||
|
'conversation_metadata',
|
||||||
|
sa.Column('public', sa.Boolean(), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f('ix_conversation_metadata_public'),
|
||||||
|
'conversation_metadata',
|
||||||
|
['public'],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
op.drop_index(
|
||||||
|
op.f('ix_conversation_metadata_public'),
|
||||||
|
table_name='conversation_metadata',
|
||||||
|
)
|
||||||
|
op.drop_column('conversation_metadata', 'public')
|
||||||
@@ -22,7 +22,7 @@ event_service_dependency = depends_event_service()
|
|||||||
@router.get('/search')
|
@router.get('/search')
|
||||||
async def search_events(
|
async def search_events(
|
||||||
conversation_id__eq: Annotated[
|
conversation_id__eq: Annotated[
|
||||||
UUID | None,
|
str | None,
|
||||||
Query(title='Optional filter by conversation ID'),
|
Query(title='Optional filter by conversation ID'),
|
||||||
] = None,
|
] = None,
|
||||||
kind__eq: Annotated[
|
kind__eq: Annotated[
|
||||||
@@ -55,7 +55,7 @@ async def search_events(
|
|||||||
assert limit > 0
|
assert limit > 0
|
||||||
assert limit <= 100
|
assert limit <= 100
|
||||||
return await event_service.search_events(
|
return await event_service.search_events(
|
||||||
conversation_id__eq=conversation_id__eq,
|
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
|
||||||
kind__eq=kind__eq,
|
kind__eq=kind__eq,
|
||||||
timestamp__gte=timestamp__gte,
|
timestamp__gte=timestamp__gte,
|
||||||
timestamp__lt=timestamp__lt,
|
timestamp__lt=timestamp__lt,
|
||||||
@@ -68,7 +68,7 @@ async def search_events(
|
|||||||
@router.get('/count')
|
@router.get('/count')
|
||||||
async def count_events(
|
async def count_events(
|
||||||
conversation_id__eq: Annotated[
|
conversation_id__eq: Annotated[
|
||||||
UUID | None,
|
str | None,
|
||||||
Query(title='Optional filter by conversation ID'),
|
Query(title='Optional filter by conversation ID'),
|
||||||
] = None,
|
] = None,
|
||||||
kind__eq: Annotated[
|
kind__eq: Annotated[
|
||||||
@@ -91,7 +91,7 @@ async def count_events(
|
|||||||
) -> int:
|
) -> int:
|
||||||
"""Count events matching the given filters."""
|
"""Count events matching the given filters."""
|
||||||
return await event_service.count_events(
|
return await event_service.count_events(
|
||||||
conversation_id__eq=conversation_id__eq,
|
conversation_id__eq=UUID(conversation_id__eq) if conversation_id__eq else None,
|
||||||
kind__eq=kind__eq,
|
kind__eq=kind__eq,
|
||||||
timestamp__gte=timestamp__gte,
|
timestamp__gte=timestamp__gte,
|
||||||
timestamp__lt=timestamp__lt,
|
timestamp__lt=timestamp__lt,
|
||||||
|
|||||||
@@ -1,32 +1,27 @@
|
|||||||
"""Filesystem-based EventService implementation."""
|
"""Filesystem-based EventService implementation."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import glob
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
|
||||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||||
AppConversationInfoService,
|
AppConversationInfoService,
|
||||||
)
|
)
|
||||||
from openhands.app_server.errors import OpenHandsError
|
from openhands.app_server.errors import OpenHandsError
|
||||||
from openhands.app_server.event.event_service import EventService, EventServiceInjector
|
from openhands.app_server.event.event_service import EventService, EventServiceInjector
|
||||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
from openhands.app_server.event.filesystem_event_service_base import (
|
||||||
|
FilesystemEventServiceBase,
|
||||||
|
)
|
||||||
from openhands.app_server.services.injector import InjectorState
|
from openhands.app_server.services.injector import InjectorState
|
||||||
from openhands.sdk import Event
|
from openhands.sdk import Event
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FilesystemEventService(EventService):
|
class FilesystemEventService(FilesystemEventServiceBase, EventService):
|
||||||
"""Filesystem-based implementation of EventService.
|
"""Filesystem-based implementation of EventService.
|
||||||
|
|
||||||
Events are stored in files with the naming format:
|
Events are stored in files with the naming format:
|
||||||
@@ -47,25 +42,6 @@ class FilesystemEventService(EventService):
|
|||||||
events_path.mkdir(parents=True, exist_ok=True)
|
events_path.mkdir(parents=True, exist_ok=True)
|
||||||
return events_path
|
return events_path
|
||||||
|
|
||||||
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
|
|
||||||
"""Convert timestamp to YYYYMMDDHHMMSS format."""
|
|
||||||
if isinstance(timestamp, str):
|
|
||||||
# Parse ISO format timestamp string
|
|
||||||
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
|
||||||
return dt.strftime('%Y%m%d%H%M%S')
|
|
||||||
return timestamp.strftime('%Y%m%d%H%M%S')
|
|
||||||
|
|
||||||
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
|
|
||||||
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
|
|
||||||
timestamp_str = self._timestamp_to_str(event.timestamp)
|
|
||||||
kind = event.__class__.__name__
|
|
||||||
# Handle both UUID objects and string UUIDs
|
|
||||||
if isinstance(event.id, str):
|
|
||||||
id_hex = event.id.replace('-', '')
|
|
||||||
else:
|
|
||||||
id_hex = event.id.hex
|
|
||||||
return f'{timestamp_str}_{kind}_{id_hex}'
|
|
||||||
|
|
||||||
def _save_event_to_file(self, conversation_id: UUID, event: Event) -> None:
|
def _save_event_to_file(self, conversation_id: UUID, event: Event) -> None:
|
||||||
"""Save an event to a file."""
|
"""Save an event to a file."""
|
||||||
events_path = self._ensure_events_dir(conversation_id)
|
events_path = self._ensure_events_dir(conversation_id)
|
||||||
@@ -77,60 +53,17 @@ class FilesystemEventService(EventService):
|
|||||||
data = event.model_dump(mode='json')
|
data = event.model_dump(mode='json')
|
||||||
f.write(json.dumps(data, indent=2))
|
f.write(json.dumps(data, indent=2))
|
||||||
|
|
||||||
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
|
async def save_event(self, conversation_id: UUID, event: Event):
|
||||||
events = []
|
"""Save an event. Internal method intended not be part of the REST api."""
|
||||||
for file_path in file_paths:
|
conversation = (
|
||||||
event = self._load_event_from_file(file_path)
|
await self.app_conversation_info_service.get_app_conversation_info(
|
||||||
if event is not None:
|
conversation_id
|
||||||
events.append(event)
|
)
|
||||||
return events
|
)
|
||||||
|
if not conversation:
|
||||||
def _load_event_from_file(self, filepath: Path) -> Event | None:
|
# This is either an illegal state or somebody is trying to hack
|
||||||
"""Load an event from a file."""
|
raise OpenHandsError('No such conversation: {conversaiont_id}')
|
||||||
try:
|
self._save_event_to_file(conversation_id, event)
|
||||||
json_data = filepath.read_text()
|
|
||||||
return Event.model_validate_json(json_data)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_event_files_by_pattern(
|
|
||||||
self, pattern: str, conversation_id: UUID | None = None
|
|
||||||
) -> list[Path]:
|
|
||||||
"""Get event files matching a glob pattern, sorted by timestamp."""
|
|
||||||
if conversation_id:
|
|
||||||
search_path = self.events_dir / str(conversation_id) / pattern
|
|
||||||
else:
|
|
||||||
search_path = self.events_dir / '*' / pattern
|
|
||||||
|
|
||||||
files = glob.glob(str(search_path))
|
|
||||||
return sorted([Path(f) for f in files])
|
|
||||||
|
|
||||||
def _parse_filename(self, filename: str) -> dict[str, str] | None:
|
|
||||||
"""Parse filename to extract timestamp, kind, and event_id."""
|
|
||||||
try:
|
|
||||||
parts = filename.split('_')
|
|
||||||
if len(parts) >= 3:
|
|
||||||
timestamp_str = parts[0]
|
|
||||||
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
|
|
||||||
event_id = parts[-1]
|
|
||||||
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_conversation_id(self, file: Path) -> UUID | None:
|
|
||||||
try:
|
|
||||||
return UUID(file.parent.name)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
|
|
||||||
result = set()
|
|
||||||
for file in files:
|
|
||||||
conversation_id = self._get_conversation_id(file)
|
|
||||||
if conversation_id:
|
|
||||||
result.add(conversation_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
|
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
|
||||||
conversation_ids = list(self._get_conversation_ids(files))
|
conversation_ids = list(self._get_conversation_ids(files))
|
||||||
@@ -150,161 +83,6 @@ class FilesystemEventService(EventService):
|
|||||||
]
|
]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _filter_files_by_criteria(
|
|
||||||
self,
|
|
||||||
files: list[Path],
|
|
||||||
conversation_id__eq: UUID | None = None,
|
|
||||||
kind__eq: EventKind | None = None,
|
|
||||||
timestamp__gte: datetime | None = None,
|
|
||||||
timestamp__lt: datetime | None = None,
|
|
||||||
) -> list[Path]:
|
|
||||||
"""Filter files based on search criteria."""
|
|
||||||
filtered_files = []
|
|
||||||
|
|
||||||
for file_path in files:
|
|
||||||
# Check conversation_id filter
|
|
||||||
if conversation_id__eq:
|
|
||||||
if str(conversation_id__eq) not in str(file_path):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Parse filename for additional filtering
|
|
||||||
filename_info = self._parse_filename(file_path.name)
|
|
||||||
if not filename_info:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check kind filter
|
|
||||||
if kind__eq and filename_info['kind'] != kind__eq:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check timestamp filters
|
|
||||||
if timestamp__gte or timestamp__lt:
|
|
||||||
try:
|
|
||||||
file_timestamp = datetime.strptime(
|
|
||||||
filename_info['timestamp'], '%Y%m%d%H%M%S'
|
|
||||||
)
|
|
||||||
if timestamp__gte and file_timestamp < timestamp__gte:
|
|
||||||
continue
|
|
||||||
if timestamp__lt and file_timestamp >= timestamp__lt:
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
filtered_files.append(file_path)
|
|
||||||
|
|
||||||
return filtered_files
|
|
||||||
|
|
||||||
async def get_event(self, event_id: str) -> Event | None:
|
|
||||||
"""Get the event with the given id, or None if not found."""
|
|
||||||
# Convert event_id to hex format (remove dashes) for filename matching
|
|
||||||
if isinstance(event_id, str) and '-' in event_id:
|
|
||||||
id_hex = event_id.replace('-', '')
|
|
||||||
else:
|
|
||||||
id_hex = event_id
|
|
||||||
|
|
||||||
# Use glob pattern to find files ending with the event_id
|
|
||||||
pattern = f'*_{id_hex}'
|
|
||||||
files = self._get_event_files_by_pattern(pattern)
|
|
||||||
|
|
||||||
if not files:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# If there is no access to the conversation do not return the event
|
|
||||||
file = files[0]
|
|
||||||
conversation_id = self._get_conversation_id(file)
|
|
||||||
if not conversation_id:
|
|
||||||
return None
|
|
||||||
conversation = (
|
|
||||||
await self.app_conversation_info_service.get_app_conversation_info(
|
|
||||||
conversation_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not conversation:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Load and return the first matching event
|
|
||||||
return self._load_event_from_file(file)
|
|
||||||
|
|
||||||
async def search_events(
|
|
||||||
self,
|
|
||||||
conversation_id__eq: UUID | None = None,
|
|
||||||
kind__eq: EventKind | None = None,
|
|
||||||
timestamp__gte: datetime | None = None,
|
|
||||||
timestamp__lt: datetime | None = None,
|
|
||||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
|
||||||
page_id: str | None = None,
|
|
||||||
limit: int = 100,
|
|
||||||
) -> EventPage:
|
|
||||||
"""Search for events matching the given filters."""
|
|
||||||
# Build the search pattern
|
|
||||||
pattern = '*'
|
|
||||||
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
|
||||||
|
|
||||||
files = await self._filter_files_by_conversation(files)
|
|
||||||
|
|
||||||
files = self._filter_files_by_criteria(
|
|
||||||
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
|
||||||
)
|
|
||||||
|
|
||||||
files.sort(
|
|
||||||
key=lambda f: f.name,
|
|
||||||
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle pagination
|
|
||||||
start_index = 0
|
|
||||||
if page_id:
|
|
||||||
for i, file_path in enumerate(files):
|
|
||||||
if file_path.name == page_id:
|
|
||||||
start_index = i + 1
|
|
||||||
break
|
|
||||||
|
|
||||||
# Collect items for this page
|
|
||||||
page_files = files[start_index : start_index + limit]
|
|
||||||
next_page_id = None
|
|
||||||
if start_index + limit < len(files):
|
|
||||||
next_page_id = files[start_index + limit].name
|
|
||||||
|
|
||||||
# Load all events from files in a background thread.
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
page_events = await loop.run_in_executor(
|
|
||||||
None, self._load_events_from_files, page_files
|
|
||||||
)
|
|
||||||
|
|
||||||
return EventPage(items=page_events, next_page_id=next_page_id)
|
|
||||||
|
|
||||||
async def count_events(
|
|
||||||
self,
|
|
||||||
conversation_id__eq: UUID | None = None,
|
|
||||||
kind__eq: EventKind | None = None,
|
|
||||||
timestamp__gte: datetime | None = None,
|
|
||||||
timestamp__lt: datetime | None = None,
|
|
||||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
|
||||||
) -> int:
|
|
||||||
"""Count events matching the given filters."""
|
|
||||||
# Build the search pattern
|
|
||||||
pattern = '*'
|
|
||||||
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
|
||||||
|
|
||||||
files = await self._filter_files_by_conversation(files)
|
|
||||||
|
|
||||||
files = self._filter_files_by_criteria(
|
|
||||||
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
|
||||||
)
|
|
||||||
|
|
||||||
return len(files)
|
|
||||||
|
|
||||||
async def save_event(self, conversation_id: UUID, event: Event):
|
|
||||||
"""Save an event. Internal method intended not be part of the REST api."""
|
|
||||||
conversation = (
|
|
||||||
await self.app_conversation_info_service.get_app_conversation_info(
|
|
||||||
conversation_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not conversation:
|
|
||||||
# This is either an illegal state or somebody is trying to hack
|
|
||||||
raise OpenHandsError('No such conversation: {conversaiont_id}')
|
|
||||||
self._save_event_to_file(conversation_id, event)
|
|
||||||
|
|
||||||
|
|
||||||
class FilesystemEventServiceInjector(EventServiceInjector):
|
class FilesystemEventServiceInjector(EventServiceInjector):
|
||||||
async def inject(
|
async def inject(
|
||||||
|
|||||||
224
openhands/app_server/event/filesystem_event_service_base.py
Normal file
224
openhands/app_server/event/filesystem_event_service_base.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
import asyncio
|
||||||
|
import glob
|
||||||
|
from abc import abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||||
|
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||||
|
from openhands.sdk import Event
|
||||||
|
|
||||||
|
|
||||||
|
class FilesystemEventServiceBase:
|
||||||
|
events_dir: Path
|
||||||
|
|
||||||
|
async def get_event(self, event_id: str) -> Event | None:
|
||||||
|
"""Get the event with the given id, or None if not found."""
|
||||||
|
# Convert event_id to hex format (remove dashes) for filename matching
|
||||||
|
if isinstance(event_id, str) and '-' in event_id:
|
||||||
|
id_hex = event_id.replace('-', '')
|
||||||
|
else:
|
||||||
|
id_hex = event_id
|
||||||
|
|
||||||
|
# Use glob pattern to find files ending with the event_id
|
||||||
|
pattern = f'*_{id_hex}'
|
||||||
|
files = self._get_event_files_by_pattern(pattern)
|
||||||
|
|
||||||
|
files = await self._filter_files_by_conversation(files)
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Load and return the first matching event
|
||||||
|
return self._load_event_from_file(files[0])
|
||||||
|
|
||||||
|
async def search_events(
|
||||||
|
self,
|
||||||
|
conversation_id__eq: UUID | None = None,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||||
|
page_id: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> EventPage:
|
||||||
|
"""Search for events matching the given filters."""
|
||||||
|
# Build the search pattern
|
||||||
|
pattern = '*'
|
||||||
|
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
||||||
|
|
||||||
|
files = await self._filter_files_by_conversation(files)
|
||||||
|
|
||||||
|
files = self._filter_files_by_criteria(
|
||||||
|
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
||||||
|
)
|
||||||
|
|
||||||
|
files.sort(
|
||||||
|
key=lambda f: f.name,
|
||||||
|
reverse=(sort_order == EventSortOrder.TIMESTAMP_DESC),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle pagination
|
||||||
|
start_index = 0
|
||||||
|
if page_id:
|
||||||
|
for i, file_path in enumerate(files):
|
||||||
|
if file_path.name == page_id:
|
||||||
|
start_index = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
# Collect items for this page
|
||||||
|
page_files = files[start_index : start_index + limit]
|
||||||
|
next_page_id = None
|
||||||
|
if start_index + limit < len(files):
|
||||||
|
next_page_id = files[start_index + limit].name
|
||||||
|
|
||||||
|
# Load all events from files in a background thread.
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
page_events = await loop.run_in_executor(
|
||||||
|
None, self._load_events_from_files, page_files
|
||||||
|
)
|
||||||
|
|
||||||
|
return EventPage(items=page_events, next_page_id=next_page_id)
|
||||||
|
|
||||||
|
async def count_events(
|
||||||
|
self,
|
||||||
|
conversation_id__eq: UUID | None = None,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||||
|
) -> int:
|
||||||
|
"""Count events matching the given filters."""
|
||||||
|
# Build the search pattern
|
||||||
|
pattern = '*'
|
||||||
|
files = self._get_event_files_by_pattern(pattern, conversation_id__eq)
|
||||||
|
|
||||||
|
files = await self._filter_files_by_conversation(files)
|
||||||
|
|
||||||
|
files = self._filter_files_by_criteria(
|
||||||
|
files, conversation_id__eq, kind__eq, timestamp__gte, timestamp__lt
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(files)
|
||||||
|
|
||||||
|
def _get_event_filename(self, conversation_id: UUID, event: Event) -> str:
|
||||||
|
"""Generate filename using YYYYMMDDHHMMSS_kind_id.hex format."""
|
||||||
|
timestamp_str = self._timestamp_to_str(event.timestamp)
|
||||||
|
kind = event.__class__.__name__
|
||||||
|
# Handle both UUID objects and string UUIDs
|
||||||
|
if isinstance(event.id, str):
|
||||||
|
id_hex = event.id.replace('-', '')
|
||||||
|
else:
|
||||||
|
id_hex = event.id.hex
|
||||||
|
return f'{timestamp_str}_{kind}_{id_hex}'
|
||||||
|
|
||||||
|
def _timestamp_to_str(self, timestamp: datetime | str) -> str:
|
||||||
|
"""Convert timestamp to YYYYMMDDHHMMSS format."""
|
||||||
|
if isinstance(timestamp, str):
|
||||||
|
# Parse ISO format timestamp string
|
||||||
|
dt = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
||||||
|
return dt.strftime('%Y%m%d%H%M%S')
|
||||||
|
return timestamp.strftime('%Y%m%d%H%M%S')
|
||||||
|
|
||||||
|
def _load_events_from_files(self, file_paths: list[Path]) -> list[Event]:
|
||||||
|
events = []
|
||||||
|
for file_path in file_paths:
|
||||||
|
event = self._load_event_from_file(file_path)
|
||||||
|
if event is not None:
|
||||||
|
events.append(event)
|
||||||
|
return events
|
||||||
|
|
||||||
|
def _load_event_from_file(self, filepath: Path) -> Event | None:
|
||||||
|
"""Load an event from a file."""
|
||||||
|
try:
|
||||||
|
json_data = filepath.read_text()
|
||||||
|
return Event.model_validate_json(json_data)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_event_files_by_pattern(
|
||||||
|
self, pattern: str, conversation_id: UUID | None = None
|
||||||
|
) -> list[Path]:
|
||||||
|
"""Get event files matching a glob pattern, sorted by timestamp."""
|
||||||
|
if conversation_id:
|
||||||
|
search_path = self.events_dir / str(conversation_id) / pattern
|
||||||
|
else:
|
||||||
|
search_path = self.events_dir / '*' / pattern
|
||||||
|
|
||||||
|
files = glob.glob(str(search_path))
|
||||||
|
return sorted([Path(f) for f in files])
|
||||||
|
|
||||||
|
def _parse_filename(self, filename: str) -> dict[str, str] | None:
|
||||||
|
"""Parse filename to extract timestamp, kind, and event_id."""
|
||||||
|
try:
|
||||||
|
parts = filename.split('_')
|
||||||
|
if len(parts) >= 3:
|
||||||
|
timestamp_str = parts[0]
|
||||||
|
kind = '_'.join(parts[1:-1]) # Handle kinds with underscores
|
||||||
|
event_id = parts[-1]
|
||||||
|
return {'timestamp': timestamp_str, 'kind': kind, 'event_id': event_id}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_conversation_id(self, file: Path) -> UUID | None:
|
||||||
|
try:
|
||||||
|
return UUID(file.parent.name)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_conversation_ids(self, files: list[Path]) -> set[UUID]:
|
||||||
|
result = set()
|
||||||
|
for file in files:
|
||||||
|
conversation_id = self._get_conversation_id(file)
|
||||||
|
if conversation_id:
|
||||||
|
result.add(conversation_id)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _filter_files_by_conversation(self, files: list[Path]) -> list[Path]:
|
||||||
|
"""Filter files by conversation."""
|
||||||
|
|
||||||
|
def _filter_files_by_criteria(
|
||||||
|
self,
|
||||||
|
files: list[Path],
|
||||||
|
conversation_id__eq: UUID | None = None,
|
||||||
|
kind__eq: EventKind | None = None,
|
||||||
|
timestamp__gte: datetime | None = None,
|
||||||
|
timestamp__lt: datetime | None = None,
|
||||||
|
) -> list[Path]:
|
||||||
|
"""Filter files based on search criteria."""
|
||||||
|
filtered_files = []
|
||||||
|
|
||||||
|
for file_path in files:
|
||||||
|
# Check conversation_id filter
|
||||||
|
if conversation_id__eq:
|
||||||
|
if str(conversation_id__eq) not in str(file_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse filename for additional filtering
|
||||||
|
filename_info = self._parse_filename(file_path.name)
|
||||||
|
if not filename_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check kind filter
|
||||||
|
if kind__eq and filename_info['kind'] != kind__eq:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check timestamp filters
|
||||||
|
if timestamp__gte or timestamp__lt:
|
||||||
|
try:
|
||||||
|
file_timestamp = datetime.strptime(
|
||||||
|
filename_info['timestamp'], '%Y%m%d%H%M%S'
|
||||||
|
)
|
||||||
|
if timestamp__gte and file_timestamp < timestamp__gte:
|
||||||
|
continue
|
||||||
|
if timestamp__lt and file_timestamp >= timestamp__lt:
|
||||||
|
continue
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered_files.append(file_path)
|
||||||
|
|
||||||
|
return filtered_files
|
||||||
@@ -29,3 +29,4 @@ class ConversationInfo:
|
|||||||
pr_number: list[int] = field(default_factory=list)
|
pr_number: list[int] = field(default_factory=list)
|
||||||
conversation_version: str = 'V0'
|
conversation_version: str = 'V0'
|
||||||
sub_conversation_ids: list[str] = field(default_factory=list)
|
sub_conversation_ids: list[str] = field(default_factory=list)
|
||||||
|
public: bool | None = None
|
||||||
|
|||||||
@@ -1501,4 +1501,5 @@ def _to_conversation_info(app_conversation: AppConversation) -> ConversationInfo
|
|||||||
sub_conversation_ids=[
|
sub_conversation_ids=[
|
||||||
sub_id.hex for sub_id in app_conversation.sub_conversation_ids
|
sub_id.hex for sub_id in app_conversation.sub_conversation_ids
|
||||||
],
|
],
|
||||||
|
public=app_conversation.public,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,3 +39,4 @@ class ConversationMetadata:
|
|||||||
# V1 compatibility
|
# V1 compatibility
|
||||||
sandbox_id: str | None = None
|
sandbox_id: str | None = None
|
||||||
conversation_version: str | None = None
|
conversation_version: str | None = None
|
||||||
|
public: bool | None = None
|
||||||
|
|||||||
15
poetry.lock
generated
15
poetry.lock
generated
@@ -12707,18 +12707,19 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-asyncio"
|
name = "pytest-asyncio"
|
||||||
version = "1.1.0"
|
version = "1.3.0"
|
||||||
description = "Pytest support for asyncio"
|
description = "Pytest support for asyncio"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.10"
|
||||||
groups = ["test"]
|
groups = ["dev", "test"]
|
||||||
files = [
|
files = [
|
||||||
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
|
{file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"},
|
||||||
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
|
{file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
pytest = ">=8.2,<9"
|
pytest = ">=8.2,<10"
|
||||||
|
typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""}
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
|
||||||
@@ -16823,4 +16824,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = "^3.12,<3.14"
|
python-versions = "^3.12,<3.14"
|
||||||
content-hash = "9764f3b69ec8ed35feebd78a826bbc6bfa4ac6d5b56bc999be8bc738b644e538"
|
content-hash = "e24ceb52bccd0c80f52c408215ccf007475eb69e10b895053ea49c7e3e4be3b8"
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ pre-commit = "4.2.0"
|
|||||||
build = "*"
|
build = "*"
|
||||||
types-setuptools = "*"
|
types-setuptools = "*"
|
||||||
pytest = "^8.4.0"
|
pytest = "^8.4.0"
|
||||||
|
pytest-asyncio = "^1.3.0"
|
||||||
|
|
||||||
[tool.poetry.group.test]
|
[tool.poetry.group.test]
|
||||||
optional = true
|
optional = true
|
||||||
|
|||||||
Reference in New Issue
Block a user