APP-307 Add Google Cloud Storage-based EventService implementation (#12264)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2026-01-06 15:52:07 -07:00
committed by GitHub
parent af5c22700c
commit fa974f8106
13 changed files with 447 additions and 435 deletions

View File

@@ -11,10 +11,15 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from google.cloud import storage
from google.cloud.storage.bucket import Bucket
from google.cloud.storage.client import Client
from more_itertools import bucket
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
@@ -28,6 +33,9 @@ from server.sharing.sql_shared_conversation_info_service import (
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.app_server.event.google_cloud_event_service import (
GoogleCloudEventService,
)
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
@@ -36,17 +44,13 @@ logger = logging.getLogger(__name__)
@dataclass
class SharedEventServiceImpl(SharedEventService):
class GoogleCloudSharedEventService(SharedEventService):
"""Implementation of SharedEventService that validates shared access."""
shared_conversation_info_service: SharedConversationInfoService
event_service: EventService
bucket: Bucket
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
# First check if the conversation is shared
async def get_event_service(self, conversation_id: UUID) -> EventService | None:
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
@@ -55,8 +59,25 @@ class SharedEventServiceImpl(SharedEventService):
if shared_conversation_info is None:
return None
return GoogleCloudEventService(
bucket=bucket,
prefix=Path('users'),
user_id=shared_conversation_info.created_by_user_id,
app_conversation_info_service=None,
app_conversation_info_load_tasks={},
)
async def get_shared_event(
self, conversation_id: UUID, event_id: UUID
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
# First check if the conversation is shared
event_service = await self.get_event_service(conversation_id)
if event_service is None:
return None
# If conversation is shared, get the event
return await self.event_service.get_event(event_id)
return await event_service.get_event(conversation_id, event_id)
async def search_shared_events(
self,
@@ -70,18 +91,14 @@ class SharedEventServiceImpl(SharedEventService):
) -> EventPage:
"""Search events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
event_service = await self.get_event_service(conversation_id)
if event_service is None:
# Return empty page if conversation is not shared
return EventPage(items=[], next_page_id=None)
# If conversation is shared, search events for this conversation
return await self.event_service.search_events(
conversation_id__eq=conversation_id,
return await event_service.search_events(
conversation_id=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
@@ -96,47 +113,41 @@ class SharedEventServiceImpl(SharedEventService):
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
event_service = await self.get_event_service(conversation_id)
if event_service is None:
# Return empty page if conversation is not shared
return 0
# If conversation is shared, count events for this conversation
return await self.event_service.count_events(
conversation_id__eq=conversation_id,
return await event_service.count_events(
conversation_id=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
class SharedEventServiceImplInjector(SharedEventServiceInjector):
class GoogleCloudSharedEventServiceInjector(SharedEventServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedEventService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import (
get_db_session,
get_event_service,
)
from openhands.app_server.config import get_db_session
async with (
get_db_session(state, request) as db_session,
get_event_service(state, request) as event_service,
):
async with get_db_session(state, request) as db_session:
shared_conversation_info_service = SQLSharedConversationInfoService(
db_session=db_session
)
service = SharedEventServiceImpl(
bucket_name = self.bucket_name
storage_client: Client = storage.Client()
bucket: Bucket = storage_client.bucket(bucket_name)
service = GoogleCloudSharedEventService(
shared_conversation_info_service=shared_conversation_info_service,
event_service=event_service,
bucket=bucket,
)
yield service

View File

@@ -5,8 +5,8 @@ from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImplInjector,
from server.sharing.google_cloud_shared_event_service import (
GoogleCloudSharedEventServiceInjector,
)
from server.sharing.shared_event_service import SharedEventService
@@ -15,7 +15,9 @@ from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
shared_event_service_dependency = Depends(
GoogleCloudSharedEventServiceInjector().depends
)
# Read methods
@@ -85,10 +87,6 @@ async def count_shared_events(
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> int:
"""Count events for a shared conversation matching the given filters."""
@@ -97,14 +95,13 @@ async def count_shared_events(
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
@router.get('')
async def batch_get_shared_events(
conversation_id: Annotated[
UUID,
str,
Query(title='Conversation ID to get events for'),
],
id: Annotated[list[str], Query()],
@@ -112,15 +109,20 @@ async def batch_get_shared_events(
) -> list[Event | None]:
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
assert len(id) <= 100
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
event_ids = [UUID(id_) for id_ in id]
events = await shared_event_service.batch_get_shared_events(
UUID(conversation_id), event_ids
)
return events
@router.get('/{conversation_id}/{event_id}')
async def get_shared_event(
conversation_id: UUID,
conversation_id: str,
event_id: str,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
"""Get a single event from a shared conversation by conversation_id and event_id."""
return await shared_event_service.get_shared_event(conversation_id, event_id)
return await shared_event_service.get_shared_event(
UUID(conversation_id), UUID(event_id)
)

View File

@@ -18,7 +18,7 @@ class SharedEventService(ABC):
@abstractmethod
async def get_shared_event(
self, conversation_id: UUID, event_id: str
self, conversation_id: UUID, event_id: UUID
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
@@ -42,12 +42,11 @@ class SharedEventService(ABC):
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
async def batch_get_shared_events(
self, conversation_id: UUID, event_ids: list[str]
self, conversation_id: UUID, event_ids: list[UUID]
) -> list[Event | None]:
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
return await asyncio.gather(

View File

@@ -1,12 +1,12 @@
"""Tests for SharedEventService."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImpl,
from server.sharing.google_cloud_shared_event_service import (
GoogleCloudSharedEventService,
)
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
@@ -25,18 +25,24 @@ def mock_shared_conversation_info_service():
return AsyncMock(spec=SharedConversationInfoService)
@pytest.fixture
def mock_bucket():
"""Create a mock GCS bucket."""
return MagicMock()
@pytest.fixture
def mock_event_service():
"""Create a mock EventService."""
"""Create a mock EventService for returned by get_event_service."""
return AsyncMock(spec=EventService)
@pytest.fixture
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
def shared_event_service(mock_shared_conversation_info_service, mock_bucket):
"""Create a SharedEventService for testing."""
return SharedEventServiceImpl(
return GoogleCloudSharedEventService(
shared_conversation_info_service=mock_shared_conversation_info_service,
event_service=mock_event_service,
bucket=mock_bucket,
)
@@ -79,11 +85,16 @@ class TestSharedEventService:
):
"""Test that get_shared_event returns an event for a public conversation."""
conversation_id = sample_public_conversation.id
event_id = 'test_event_id'
event_id = uuid4()
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock get_event_service to return our mock event service
shared_event_service.get_event_service = AsyncMock(
return_value=mock_event_service
)
# Mock the event service to return an event
mock_event_service.get_event.return_value = sample_event
@@ -92,10 +103,8 @@ class TestSharedEventService:
# Verify the result
assert result == sample_event
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.get_event.assert_called_once_with(event_id)
shared_event_service.get_event_service.assert_called_once_with(conversation_id)
mock_event_service.get_event.assert_called_once_with(conversation_id, event_id)
async def test_get_shared_event_returns_none_for_private_conversation(
self,
@@ -105,20 +114,18 @@ class TestSharedEventService:
):
"""Test that get_shared_event returns None for a private conversation."""
conversation_id = uuid4()
event_id = 'test_event_id'
event_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Mock get_event_service to return None (private conversation)
shared_event_service.get_event_service = AsyncMock(return_value=None)
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
shared_event_service.get_event_service.assert_called_once_with(conversation_id)
# Event service should not be called since get_event_service returns None
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_returns_events_for_public_conversation(
@@ -132,8 +139,10 @@ class TestSharedEventService:
"""Test that search_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock get_event_service to return our mock event service
shared_event_service.get_event_service = AsyncMock(
return_value=mock_event_service
)
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id=None)
@@ -150,11 +159,9 @@ class TestSharedEventService:
assert result == mock_event_page
assert len(result.items) == 0 # Empty list as we mocked
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
shared_event_service.get_event_service.assert_called_once_with(conversation_id)
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
conversation_id=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
@@ -172,8 +179,8 @@ class TestSharedEventService:
"""Test that search_shared_events returns empty page for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Mock get_event_service to return None (private conversation)
shared_event_service.get_event_service = AsyncMock(return_value=None)
# Call the method
result = await shared_event_service.search_shared_events(
@@ -186,9 +193,7 @@ class TestSharedEventService:
assert len(result.items) == 0
assert result.next_page_id is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
shared_event_service.get_event_service.assert_called_once_with(conversation_id)
# Event service should not be called
mock_event_service.search_events.assert_not_called()
@@ -202,8 +207,10 @@ class TestSharedEventService:
"""Test that count_shared_events returns count for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock get_event_service to return our mock event service
shared_event_service.get_event_service = AsyncMock(
return_value=mock_event_service
)
# Mock the event service to return a count
mock_event_service.count_events.return_value = 5
@@ -217,15 +224,12 @@ class TestSharedEventService:
# Verify the result
assert result == 5
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
shared_event_service.get_event_service.assert_called_once_with(conversation_id)
mock_event_service.count_events.assert_called_once_with(
conversation_id__eq=conversation_id,
conversation_id=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
)
async def test_count_shared_events_returns_zero_for_private_conversation(
@@ -237,8 +241,8 @@ class TestSharedEventService:
"""Test that count_shared_events returns 0 for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Mock get_event_service to return None (private conversation)
shared_event_service.get_event_service = AsyncMock(return_value=None)
# Call the method
result = await shared_event_service.count_shared_events(
@@ -248,9 +252,7 @@ class TestSharedEventService:
# Verify the result
assert result == 0
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
shared_event_service.get_event_service.assert_called_once_with(conversation_id)
# Event service should not be called
mock_event_service.count_events.assert_not_called()
@@ -264,10 +266,12 @@ class TestSharedEventService:
):
"""Test that batch_get_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
event_ids = ['event1', 'event2']
event_ids = [uuid4(), uuid4()]
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock get_event_service to return our mock event service
shared_event_service.get_event_service = AsyncMock(
return_value=mock_event_service
)
# Mock the event service to return events
mock_event_service.get_event.side_effect = [sample_event, None]
@@ -282,11 +286,8 @@ class TestSharedEventService:
assert result[0] == sample_event
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Verify that get_event_service was called for each event
assert shared_event_service.get_event_service.call_count == 2
# Verify that get_event was called for each event
assert mock_event_service.get_event.call_count == 2
@@ -298,10 +299,10 @@ class TestSharedEventService:
):
"""Test that batch_get_shared_events returns None for a private conversation."""
conversation_id = uuid4()
event_ids = ['event1', 'event2']
event_ids = [uuid4(), uuid4()]
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Mock get_event_service to return None (private conversation)
shared_event_service.get_event_service = AsyncMock(return_value=None)
# Call the method
result = await shared_event_service.batch_get_shared_events(
@@ -313,11 +314,8 @@ class TestSharedEventService:
assert result[0] is None
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Verify that get_event_service was called for each event
assert shared_event_service.get_event_service.call_count == 2
# Event service should not be called
mock_event_service.get_event.assert_not_called()
@@ -333,8 +331,10 @@ class TestSharedEventService:
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock get_event_service to return our mock event service
shared_event_service.get_event_service = AsyncMock(
return_value=mock_event_service
)
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id='next_page')
@@ -355,7 +355,7 @@ class TestSharedEventService:
assert result == mock_event_page
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
conversation_id=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,