mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Refactor grpc channel connection in servicer (#5402)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from _collections_abc import AsyncIterator
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Future, Task
|
||||
from typing import Any, Dict, Sequence, Set, Tuple
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar
|
||||
|
||||
from autogen_core import TopicId
|
||||
from autogen_core._runtime_impl_helpers import SubscriptionManager
|
||||
@@ -38,11 +40,66 @@ async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) ->
|
||||
return client_id # type: ignore
|
||||
|
||||
|
||||
SendT = TypeVar("SendT")
|
||||
ReceiveT = TypeVar("ReceiveT")
|
||||
|
||||
|
||||
class ChannelConnection(ABC, Generic[SendT, ReceiveT]):
|
||||
def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None:
|
||||
self._request_iterator = request_iterator
|
||||
self._client_id = client_id
|
||||
self._send_queue: asyncio.Queue[SendT] = asyncio.Queue()
|
||||
self._receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
|
||||
|
||||
async def _receive_messages(self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]) -> None:
|
||||
# Receive messages from the client and process them.
|
||||
async for message in request_iterator:
|
||||
logger.info(f"Received message from client {client_id}: {message}")
|
||||
await self._handle_message(message)
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[SendT]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> SendT:
|
||||
try:
|
||||
return await self._send_queue.get()
|
||||
except StopAsyncIteration:
|
||||
await self._receiving_task
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get message from send queue: {e}", exc_info=True)
|
||||
await self._receiving_task
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def _handle_message(self, message: ReceiveT) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, message: SendT) -> None:
|
||||
await self._send_queue.put(message)
|
||||
|
||||
|
||||
class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]):
|
||||
def __init__(
|
||||
self,
|
||||
request_iterator: AsyncIterator[ReceiveT],
|
||||
client_id: str,
|
||||
handle_callback: Callable[[ReceiveT], Awaitable[None]],
|
||||
) -> None:
|
||||
self._handle_callback = handle_callback
|
||||
super().__init__(request_iterator, client_id)
|
||||
|
||||
async def _handle_message(self, message: ReceiveT) -> None:
|
||||
await self._handle_callback(message)
|
||||
|
||||
|
||||
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {}
|
||||
self._data_connections: Dict[
|
||||
ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]
|
||||
] = {}
|
||||
self._agent_type_to_client_id_lock = asyncio.Lock()
|
||||
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
|
||||
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
|
||||
@@ -57,32 +114,21 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
) -> AsyncIterator[agent_worker_pb2.Message]:
|
||||
client_id = await get_client_id_or_abort(context)
|
||||
|
||||
# Register the client with the server and create a send queue for the client.
|
||||
send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue()
|
||||
self._send_queues[client_id] = send_queue
|
||||
async def handle_callback(message: agent_worker_pb2.Message) -> None:
|
||||
await self._receive_message(client_id, message)
|
||||
|
||||
connection = CallbackChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message](
|
||||
request_iterator, client_id, handle_callback=handle_callback
|
||||
)
|
||||
self._data_connections[client_id] = connection
|
||||
logger.info(f"Client {client_id} connected.")
|
||||
|
||||
try:
|
||||
# Concurrently handle receiving messages from the client and sending messages to the client.
|
||||
# This task will receive messages from the client.
|
||||
receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator))
|
||||
|
||||
# Return an async generator that will yield messages from the send queue to the client.
|
||||
while True:
|
||||
message = await send_queue.get()
|
||||
# Yield the message to the client.
|
||||
try:
|
||||
yield message
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True)
|
||||
break
|
||||
logger.info(f"Sent message to client {client_id}: {message}")
|
||||
# Wait for the receiving task to finish.
|
||||
await receiving_task
|
||||
|
||||
async for message in connection:
|
||||
yield message
|
||||
finally:
|
||||
# Clean up the client connection.
|
||||
del self._send_queues[client_id]
|
||||
del self._data_connections[client_id]
|
||||
# Cancel pending requests sent to this client.
|
||||
for future in self._pending_responses.pop(client_id, {}).values():
|
||||
future.cancel()
|
||||
@@ -105,33 +151,29 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
if exception is not None:
|
||||
raise exception
|
||||
|
||||
async def _receive_messages(
|
||||
self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message]
|
||||
) -> None:
|
||||
# Receive messages from the client and process them.
|
||||
async for message in request_iterator:
|
||||
logger.info(f"Received message from client {client_id}: {message}")
|
||||
oneofcase = message.WhichOneof("message")
|
||||
match oneofcase:
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
task = asyncio.create_task(self._process_response(response, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "cloudEvent":
|
||||
task = asyncio.create_task(self._process_event(message.cloudEvent))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
async def _receive_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.Message) -> None:
|
||||
logger.info(f"Received message from client {client_id}: {message}")
|
||||
oneofcase = message.WhichOneof("message")
|
||||
match oneofcase:
|
||||
case "request":
|
||||
request: agent_worker_pb2.RpcRequest = message.request
|
||||
task = asyncio.create_task(self._process_request(request, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "response":
|
||||
response: agent_worker_pb2.RpcResponse = message.response
|
||||
task = asyncio.create_task(self._process_response(response, client_id))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case "cloudEvent":
|
||||
task = asyncio.create_task(self._process_event(message.cloudEvent))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._raise_on_exception)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case None:
|
||||
logger.warning("Received empty message")
|
||||
|
||||
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
|
||||
# Deliver the message to a client given the target agent type.
|
||||
@@ -140,11 +182,11 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
if target_client_id is None:
|
||||
logger.error(f"Agent {request.target.type} not found, failed to deliver message.")
|
||||
return
|
||||
target_send_queue = self._send_queues.get(target_client_id)
|
||||
target_send_queue = self._data_connections.get(target_client_id)
|
||||
if target_send_queue is None:
|
||||
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
|
||||
return
|
||||
await target_send_queue.put(agent_worker_pb2.Message(request=request))
|
||||
await target_send_queue.send(agent_worker_pb2.Message(request=request))
|
||||
|
||||
# Create a future to wait for the response from the target.
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
@@ -161,11 +203,11 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
) -> None:
|
||||
response = await future
|
||||
message = agent_worker_pb2.Message(response=response)
|
||||
send_queue = self._send_queues.get(client_id)
|
||||
send_queue = self._data_connections.get(client_id)
|
||||
if send_queue is None:
|
||||
logger.error(f"Client {client_id} not found, failed to send response message.")
|
||||
return
|
||||
await send_queue.put(message)
|
||||
await send_queue.send(message)
|
||||
|
||||
async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
|
||||
# Setting the result of the future will send the response back to the original sender.
|
||||
@@ -186,7 +228,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
|
||||
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
|
||||
# Deliver the event to clients.
|
||||
for client_id in client_ids:
|
||||
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
|
||||
await self._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event))
|
||||
|
||||
async def RegisterAgent( # type: ignore
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user