Refactor grpc channel connection in servicer (#5402)

This commit is contained in:
Jack Gerrits
2025-02-06 13:53:24 -05:00
committed by GitHub
parent cf798aef3f
commit ca428914f5

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
import logging
from _collections_abc import AsyncIterator
from abc import ABC, abstractmethod
from asyncio import Future, Task
from typing import Any, Dict, Sequence, Set, Tuple
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar
from autogen_core import TopicId
from autogen_core._runtime_impl_helpers import SubscriptionManager
@@ -38,11 +40,66 @@ async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) ->
return client_id # type: ignore
SendT = TypeVar("SendT")
ReceiveT = TypeVar("ReceiveT")
class ChannelConnection(ABC, Generic[SendT, ReceiveT]):
def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None:
self._request_iterator = request_iterator
self._client_id = client_id
self._send_queue: asyncio.Queue[SendT] = asyncio.Queue()
self._receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
async def _receive_messages(self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
logger.info(f"Received message from client {client_id}: {message}")
await self._handle_message(message)
def __aiter__(self) -> AsyncIterator[SendT]:
return self
async def __anext__(self) -> SendT:
try:
return await self._send_queue.get()
except StopAsyncIteration:
await self._receiving_task
raise
except Exception as e:
logger.error(f"Failed to get message from send queue: {e}", exc_info=True)
await self._receiving_task
raise
@abstractmethod
async def _handle_message(self, message: ReceiveT) -> None:
pass
async def send(self, message: SendT) -> None:
await self._send_queue.put(message)
class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]):
def __init__(
self,
request_iterator: AsyncIterator[ReceiveT],
client_id: str,
handle_callback: Callable[[ReceiveT], Awaitable[None]],
) -> None:
self._handle_callback = handle_callback
super().__init__(request_iterator, client_id)
async def _handle_message(self, message: ReceiveT) -> None:
await self._handle_callback(message)
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:
self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
self._data_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]
] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
@@ -57,32 +114,21 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
) -> AsyncIterator[agent_worker_pb2.Message]:
client_id = await get_client_id_or_abort(context)
# Register the client with the server and create a send queue for the client.
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
self._send_queues[client_id] = send_queue
async def handle_callback(message: agent_worker_pb2.Message) -> None:
await self._receive_message(client_id, message)
connection = CallbackChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message](
request_iterator, client_id, handle_callback=handle_callback
)
self._data_connections[client_id] = connection
logger.info(f"Client {client_id} connected.")
try:
# Concurrently handle receiving messages from the client and sending messages to the client.
# This task will receive messages from the client.
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
# Return an async generator that will yield messages from the send queue to the client.
while True:
message = await send_queue.get()
# Yield the message to the client.
try:
yield message
except Exception as e:
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
break
logger.info(f"Sent message to client {client_id}: {message}")
# Wait for the receiving task to finish.
await receiving_task
async for message in connection:
yield message
finally:
# Clean up the client connection.
del self._send_queues[client_id]
del self._data_connections[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
@@ -105,33 +151,29 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
if exception is not None:
raise exception
async def _receive_messages(
self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
) -> None:
# Receive messages from the client and process them.
async for message in request_iterator:
logger.info(f"Received message from client {client_id}: {message}")
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")
async def _receive_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.Message) -> None:
logger.info(f"Received message from client {client_id}: {message}")
oneofcase = message.WhichOneof("message")
match oneofcase:
case "request":
request: agent_worker_pb2.RpcRequest = message.request
task = asyncio.create_task(self._process_request(request, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "response":
response: agent_worker_pb2.RpcResponse = message.response
task = asyncio.create_task(self._process_response(response, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "cloudEvent":
task = asyncio.create_task(self._process_event(message.cloudEvent))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("Received empty message")
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type.
@@ -140,11 +182,11 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
if target_client_id is None:
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
return
target_send_queue = self._send_queues.get(target_client_id)
target_send_queue = self._data_connections.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.put(agent_worker_pb2.Message(request=request))
await target_send_queue.send(agent_worker_pb2.Message(request=request))
# Create a future to wait for the response from the target.
future = asyncio.get_event_loop().create_future()
@@ -161,11 +203,11 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
) -> None:
response = await future
message = agent_worker_pb2.Message(response=response)
send_queue = self._send_queues.get(client_id)
send_queue = self._data_connections.get(client_id)
if send_queue is None:
logger.error(f"Client {client_id} not found, failed to send response message.")
return
await send_queue.put(message)
await send_queue.send(message)
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
# Setting the result of the future will send the response back to the original sender.
@@ -186,7 +228,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
await self._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event))
async def RegisterAgent( # type: ignore
self,