Compare commits

...

1 Commits

Author SHA1 Message Date
Joe Laverty
32f730153e bugfix(enterprise): Remove shared conversation enumerator endpoints (#13976) 2026-04-16 14:12:41 -04:00
6 changed files with 6 additions and 635 deletions

View File

@@ -1,12 +1,9 @@
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from openhands.app_server.services.injector import Injector
@@ -16,32 +13,6 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
class SharedConversationInfoService(ABC):
"""Service for accessing shared conversation info without user restrictions."""
@abstractmethod
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
@abstractmethod
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations."""
@abstractmethod
async def get_shared_conversation_info(
self, conversation_id: UUID

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from enum import Enum
# Simplified imports to avoid dependency chain issues
# from openhands.integrations.service_types import ProviderType
@@ -40,17 +39,3 @@ class SharedConversation(BaseModel):
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
class SharedConversationSortOrder(Enum):
CREATED_AT = 'CREATED_AT'
CREATED_AT_DESC = 'CREATED_AT_DESC'
UPDATED_AT = 'UPDATED_AT'
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
TITLE = 'TITLE'
TITLE_DESC = 'TITLE_DESC'
class SharedConversationPage(BaseModel):
items: list[SharedConversation]
next_page_id: str | None = None

View File

@@ -1,6 +1,5 @@
"""Shared Conversation router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
@@ -10,8 +9,6 @@ from server.sharing.shared_conversation_info_service import (
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoServiceInjector,
@@ -22,101 +19,13 @@ shared_conversation_info_service_dependency = Depends(
SQLSharedConversationInfoServiceInjector().depends
)
# Read methods
@router.get('/search')
async def search_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
sort_order: Annotated[
SharedConversationSortOrder,
Query(title='Sort order for results'),
] = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(
title='The max number of results in the page',
gt=0,
le=100,
),
] = 100,
include_sub_conversations: Annotated[
bool,
Query(
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
),
] = False,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
"""Search / List shared conversations."""
return await shared_conversation_service.search_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
include_sub_conversations=include_sub_conversations,
)
@router.get('/count')
async def count_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> int:
"""Count shared conversations matching the given filters."""
return await shared_conversation_service.count_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
#
# These endpoints are unauthenticated. Only batch lookup by known IDs is
# exposed publicly so that share links of the form
# /shared/conversations/<id> can be viewed without auth. Listing or
# enumerating shared conversations is intentionally not exposed.
@router.get('')

View File

@@ -21,8 +21,6 @@ from server.sharing.shared_conversation_info_service import (
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -45,113 +43,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
db_session: AsyncSession
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select_with_saas_metadata()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == SharedConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == SharedConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
for stored, saas_metadata in rows
]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return SharedConversationPage(items=items, next_page_id=next_page_id)
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations matching the given filters."""
from sqlalchemy import func
query = select(func.count(StoredConversationMetadata.conversation_id))
# Only include shared conversations
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
return result.scalar() or 0
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
@@ -169,15 +60,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
stored, saas_metadata = row
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
def _public_select(self):
"""Create a select query that only returns public conversations."""
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
# Only include conversations marked as public
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _public_select_with_saas_metadata(self):
"""Create a select query that returns public conversations with SAAS metadata.
@@ -197,41 +79,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
)
return query
def _apply_filters(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply common filters to a query."""
if title__contains is not None:
query = query.where(
StoredConversationMetadata.title.contains(title__contains)
)
if created_at__gte is not None:
query = query.where(
StoredConversationMetadata.created_at >= created_at__gte
)
if created_at__lt is not None:
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
query = query.where(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
query = query.where(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
return query
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,

View File

@@ -5,8 +5,6 @@ from uuid import uuid4
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
@@ -32,42 +30,6 @@ def test_public_conversation_creation():
assert conversation.sandbox_id == 'test_sandbox'
def test_public_conversation_page_creation():
"""Test that SharedConversationPage can be created."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
page = SharedConversationPage(
items=[conversation],
next_page_id='next_page',
)
assert len(page.items) == 1
assert page.items[0].id == conversation_id
assert page.next_page_id == 'next_page'
def test_public_conversation_sort_order_enum():
"""Test that SharedConversationSortOrder enum has expected values."""
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'TITLE')
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
def test_public_conversation_optional_fields():
"""Test that SharedConversation works with optional fields."""
conversation_id = uuid4()

View File

@@ -5,9 +5,6 @@ from typing import AsyncGenerator
from uuid import UUID, uuid4
import pytest
from server.sharing.shared_conversation_models import (
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
@@ -201,157 +198,6 @@ class TestSharedConversationInfoService:
)
assert result is None
@pytest.mark.asyncio
async def test_search_shared_conversation_info_returns_only_public_conversations(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test that search only returns public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Search for all conversations
result = (
await shared_conversation_info_service.search_shared_conversation_info()
)
# Should only return the public conversation
assert len(result.items) == 1
assert result.items[0].id == sample_conversation_info.id
assert result.items[0].title == sample_conversation_info.title
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_title_filter(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test searching with title filter."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Search with matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='Test'
)
assert len(result.items) == 1
# Search with non-matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='NonExistent'
)
assert len(result.items) == 0
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_sort_order(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test searching with different sort orders."""
# Create multiple public conversations with different titles and timestamps
conv1 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_1',
title='A First Conversation',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv2 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_2',
title='B Second Conversation',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_conversation_service.save_app_conversation_info(conv1)
await app_conversation_service.save_app_conversation_info(conv2)
# Test sort by title ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE
)
assert len(result.items) == 2
assert result.items[0].title == 'A First Conversation'
assert result.items[1].title == 'B Second Conversation'
# Test sort by title descending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE_DESC
)
assert len(result.items) == 2
assert result.items[0].title == 'B Second Conversation'
assert result.items[1].title == 'A First Conversation'
# Test sort by created_at ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.items[0].id == conv1.id
assert result.items[1].id == conv2.id
# Test sort by created_at descending (default)
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
)
assert len(result.items) == 2
assert result.items[0].id == conv2.id
assert result.items[1].id == conv1.id
@pytest.mark.asyncio
async def test_count_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test counting public conversations."""
# Initially should be 0
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 0
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
# Create a private conversation - count should remain 1
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
@pytest.mark.asyncio
async def test_batch_get_shared_conversation_info(
self,
@@ -382,56 +228,6 @@ class TestSharedConversationInfoService:
assert result[0].id == sample_conversation_info.id
assert result[1] is None
@pytest.mark.asyncio
async def test_search_with_pagination(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test search with pagination."""
# Create multiple public conversations
conversations = []
for i in range(5):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id=f'test_sandbox_{i}',
title=f'Conversation {i}',
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conversations.append(conv)
await app_conversation_service.save_app_conversation_info(conv)
# Get first page with limit 2
result = await shared_conversation_info_service.search_shared_conversation_info(
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.next_page_id is not None
# Get next page
result2 = (
await shared_conversation_info_service.search_shared_conversation_info(
limit=2,
page_id=result.next_page_id,
sort_order=SharedConversationSortOrder.CREATED_AT,
)
)
assert len(result2.items) == 2
assert result2.next_page_id is not None
# Verify no overlap between pages
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)
class TestSharedConversationInfoServiceWithSaasMetadata:
"""Test cases for SharedConversationInfoService with SAAS metadata.
@@ -552,42 +348,6 @@ class TestSharedConversationInfoServiceWithSaasMetadata:
assert result is not None
assert result.created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_search_shared_conversations_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that search_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox_search',
title='Searchable Public Conversation',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.search_shared_conversation_info()
# Assert
assert len(result.items) == 1
assert result.items[0].created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_batch_get_shared_conversations_returns_user_id_from_saas_metadata(
self,
@@ -626,66 +386,3 @@ class TestSharedConversationInfoServiceWithSaasMetadata:
assert len(result) == 1
assert result[0] is not None
assert result[0].created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_mixed_conversations_with_and_without_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test handling of conversations where some have SAAS metadata and some don't."""
# Arrange
conv_with_saas_id = uuid4()
conv_without_saas_id = uuid4()
conv_with_saas = AppConversationInfo(
id=conv_with_saas_id,
created_by_user_id=None,
sandbox_id='sandbox_with_saas',
title='With SAAS Metadata',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv_without_saas = AppConversationInfo(
id=conv_without_saas_id,
created_by_user_id=None,
sandbox_id='sandbox_without_saas',
title='Without SAAS Metadata',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conv_with_saas)
await app_service_with_saas.save_app_conversation_info(conv_without_saas)
await self._create_saas_metadata(
async_session_with_saas, conv_with_saas_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
# Assert
assert len(result.items) == 2
conv_without = next(
item for item in result.items if item.id == conv_without_saas_id
)
conv_with = next(item for item in result.items if item.id == conv_with_saas_id)
assert conv_without.created_by_user_id is None
assert conv_with.created_by_user_id == str(test_user.id)