Compare commits

...

2 Commits

Author SHA1 Message Date
openhands
1d31927230 Update tests to work with removal of github_user_id parameter 2025-04-25 03:42:13 +00:00
Chuck Butkus
9a24df14f1 Remove GItHub ID where it's not needed 2025-04-24 23:02:07 -04:00
5 changed files with 16 additions and 38 deletions

View File

@@ -53,7 +53,6 @@ class ConversationManager(ABC):
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStore | None:
"""Join a conversation and return its event stream."""
@@ -82,7 +81,6 @@ class ConversationManager(ABC):
user_id: str | None,
initial_user_msg: MessageAction | None = None,
replay_json: str | None = None,
github_user_id: str | None = None,
) -> EventStore:
"""Start an event loop if one is not already running"""

View File

@@ -115,7 +115,6 @@ class StandaloneConversationManager(ConversationManager):
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStore:
logger.info(
f'join_conversation:{sid}:{connection_id}',
@@ -123,9 +122,7 @@ class StandaloneConversationManager(ConversationManager):
)
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
)
event_stream = await self.maybe_start_agent_loop(sid, settings, user_id)
if not event_stream:
logger.error(
f'No event stream after joining conversation: {sid}',
@@ -193,9 +190,7 @@ class StandaloneConversationManager(ConversationManager):
logger.error('error_cleaning_stale')
await asyncio.sleep(_CLEANUP_INTERVAL)
async def _get_conversation_store(
self, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
conversation_store_class = self._conversation_store_class
if not conversation_store_class:
self._conversation_store_class = conversation_store_class = get_impl(
@@ -252,12 +247,11 @@ class StandaloneConversationManager(ConversationManager):
user_id: str | None,
initial_user_msg: MessageAction | None = None,
replay_json: str | None = None,
github_user_id: str | None = None,
) -> EventStore:
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
if not await self.is_agent_loop_running(sid):
await self._start_agent_loop(
sid, settings, user_id, initial_user_msg, replay_json, github_user_id
sid, settings, user_id, initial_user_msg, replay_json
)
event_store = await self._get_event_store(sid, user_id)
@@ -276,7 +270,6 @@ class StandaloneConversationManager(ConversationManager):
user_id: str | None,
initial_user_msg: MessageAction | None = None,
replay_json: str | None = None,
github_user_id: str | None = None,
) -> Session:
logger.info(f'starting_agent_loop:{sid}', extra={'session_id': sid})
@@ -287,9 +280,7 @@ class StandaloneConversationManager(ConversationManager):
extra={'session_id': sid, 'user_id': user_id},
)
# Get the conversations sorted (oldest first)
conversation_store = await self._get_conversation_store(
user_id, github_user_id
)
conversation_store = await self._get_conversation_store(user_id)
conversations = await conversation_store.get_all_metadata(response_ids)
conversations.sort(key=_last_updated_at_key, reverse=True)
@@ -328,7 +319,7 @@ class StandaloneConversationManager(ConversationManager):
try:
session.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER,
self._create_conversation_update_callback(user_id, github_user_id, sid),
self._create_conversation_update_callback(user_id, sid),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@@ -425,14 +416,13 @@ class StandaloneConversationManager(ConversationManager):
)
def _create_conversation_update_callback(
self, user_id: str | None, github_user_id: str | None, conversation_id: str
self, user_id: str | None, conversation_id: str
) -> Callable:
def callback(event, *args, **kwargs):
call_async_from_sync(
self._update_conversation_for_event,
GENERAL_TIMEOUT,
user_id,
github_user_id,
conversation_id,
event,
)
@@ -440,9 +430,9 @@ class StandaloneConversationManager(ConversationManager):
return callback
async def _update_conversation_for_event(
self, user_id: str, github_user_id: str, conversation_id: str, event=None
self, user_id: str, conversation_id: str, event=None
):
conversation_store = await self._get_conversation_store(user_id, github_user_id)
conversation_store = await self._get_conversation_store(user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)

View File

@@ -60,9 +60,7 @@ async def connect(connection_id: str, environ):
cookies_str = environ.get('HTTP_COOKIE', '')
conversation_validator = create_conversation_validator()
user_id, github_user_id = await conversation_validator.validate(
conversation_id, cookies_str
)
user_id = await conversation_validator.validate(conversation_id, cookies_str)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
@@ -81,7 +79,7 @@ async def connect(connection_id: str, environ):
conversation_init_data = ConversationInitData(**session_init_args)
event_stream = await conversation_manager.join_conversation(
conversation_id, connection_id, conversation_init_data, user_id, github_user_id
conversation_id, connection_id, conversation_init_data, user_id
)
logger.info(
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'

View File

@@ -22,15 +22,10 @@ class ConversationStore(ABC):
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
"""Load conversation metadata."""
async def validate_metadata(
self, conversation_id: str, user_id: str, github_user_id: str
) -> bool:
async def validate_metadata(self, conversation_id: str, user_id: str) -> bool:
"""Validate that conversation belongs to the current user."""
# TODO: remove github_user_id after transition to Keycloak is complete.
metadata = await self.get_metadata(conversation_id)
if (not metadata.user_id and not metadata.github_user_id) or (
metadata.user_id != user_id and metadata.github_user_id != github_user_id
):
if not metadata.user_id or metadata.user_id != user_id:
return False
else:
return True

View File

@@ -73,8 +73,7 @@ async def test_init_new_local_session():
'new-session-id',
'new-session-id',
ConversationInitData(),
1,
'12345',
'12345'
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 1
@@ -118,15 +117,13 @@ async def test_join_local_session():
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
None
)
await conversation_manager.join_conversation(
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
None
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 2
@@ -159,7 +156,7 @@ async def test_add_to_local_event_stream():
'new-session-id', ConversationInitData(), 1
)
await conversation_manager.join_conversation(
'new-session-id', 'connection-id', ConversationInitData(), 1, '12345'
'new-session-id', 'connection-id', ConversationInitData(), 1
)
await conversation_manager.send_to_event_stream(
'connection-id', {'event_type': 'some_event'}