mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Initial cross-language protocol for agents (#139)
* Initial prototype of .NET gRPC worker client + service --------- Co-authored-by: Jack Gerrits <jack@jackgerrits.com>
This commit is contained in:
7
python/src/agnext/worker/__init__.py
Normal file
7
python/src/agnext/worker/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
The :mod:`agnext.worker` module provides a set of classes for creating distributed agents
|
||||
"""
|
||||
|
||||
from .worker_runtime import WorkerAgentRuntime
|
||||
|
||||
__all__ = ["WorkerAgentRuntime"]
|
||||
19
python/src/agnext/worker/protos/__init__.py
Normal file
19
python/src/agnext/worker/protos/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
The :mod:`agnext.worker.protos` module provides Google Protobuf classes for agent-worker communication
|
||||
"""
|
||||
|
||||
import os
import sys

# Put this package's directory at the front of the import path.
# NOTE(review): presumably needed so the generated *_pb2 modules can import
# each other by bare module name -- confirm against the generated code.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from typing import TYPE_CHECKING

from .agent_worker_pb2 import (
    AgentId,
    Event,
    Message,
    RegisterAgentType,
    RpcRequest,
    RpcResponse,
)
from .agent_worker_pb2_grpc import AgentRpcStub

# AgentRpcAsyncStub is imported (and exported) only while type checking; at
# runtime the export list leaves it out.
if TYPE_CHECKING:
    from .agent_worker_pb2_grpc import AgentRpcAsyncStub

    __all__ = [
        "RpcRequest",
        "RpcResponse",
        "Event",
        "RegisterAgentType",
        "AgentRpcAsyncStub",
        "AgentRpcStub",
        "Message",
        "AgentId",
    ]
else:
    __all__ = [
        "RpcRequest",
        "RpcResponse",
        "Event",
        "RegisterAgentType",
        "AgentRpcStub",
        "Message",
        "AgentId",
    ]
|
||||
397
python/src/agnext/worker/worker_runtime.py
Normal file
397
python/src/agnext/worker/worker_runtime.py
Normal file
@@ -0,0 +1,397 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from asyncio import Future, Task
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
ClassVar,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
ParamSpec,
|
||||
Set,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from typing_extensions import Self
|
||||
|
||||
from agnext.application.message_serialization import Serialization
|
||||
|
||||
from ..core import (
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentMetadata,
|
||||
AgentProxy,
|
||||
AgentRuntime,
|
||||
CancellationToken,
|
||||
)
|
||||
from .protos import AgentId as AgentIdProto
|
||||
from .protos import (
|
||||
AgentRpcStub,
|
||||
Event,
|
||||
Message,
|
||||
RegisterAgentType,
|
||||
RpcRequest,
|
||||
RpcResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .protos import AgentRpcAsyncStub
|
||||
|
||||
logger = logging.getLogger("agnext")
|
||||
event_logger = logging.getLogger("agnext.events")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class PublishMessageEnvelope:
    """Envelope wrapping a message that is published to every agent able to
    handle messages of its type.

    All fields are keyword-only.
    """

    # The payload being published.
    message: Any
    # Cancellation token carried along with the publish operation.
    cancellation_token: CancellationToken
    # Id of the publishing agent; may be None.
    sender: AgentId | None
    # Namespace associated with the published message.
    namespace: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class SendMessageEnvelope:
    """Envelope wrapping a message addressed to one specific agent.

    All fields are keyword-only.
    """

    # The payload being sent.
    message: Any
    # Id of the sending agent; may be None.
    sender: AgentId | None
    # Id of the agent the message is addressed to.
    recipient: AgentId
    # Future associated with this send.
    # NOTE(review): presumably resolved with the recipient's response -- confirm.
    future: Future[Any]
    # Cancellation token carried along with the send operation.
    cancellation_token: CancellationToken
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
    """Envelope wrapping the response to a previously sent message.

    All fields are keyword-only.
    """

    # The response payload.
    message: Any
    # Future associated with the response.
    # NOTE(review): presumably the future of the original send -- confirm.
    future: Future[Any]
    # Id of the agent that produced the response.
    sender: AgentId
    # Id of the agent the response goes back to; may be None.
    recipient: AgentId | None
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
    """Expose an :class:`asyncio.Queue` as an asynchronous iterator.

    Each iteration step waits for the next item of the wrapped queue and
    yields it. The iterator never raises ``StopAsyncIteration`` on its own, so
    iteration only ends when the consumer stops or is cancelled.
    """

    def __init__(self, queue: asyncio.Queue[Any]) -> None:
        # The queue is shared with the producer, not copied.
        self._queue = queue

    def __aiter__(self) -> AsyncIterator[Any]:
        # The object is its own iterator.
        return self

    async def __anext__(self) -> Any:
        # Suspends until the producer puts an item on the queue.
        next_item = await self._queue.get()
        return next_item
|
||||
|
||||
|
||||
class RuntimeConnection:
    """Queue-backed, bidirectional message pipe over a gRPC channel.

    Messages passed to :meth:`send` are placed on an internal send queue that
    feeds the ``OpenChannel`` streaming call; messages read from that call are
    placed on an internal receive queue and handed out by :meth:`recv`.
    The background task that pumps the stream is started by
    :meth:`from_connection_string`.
    """

    # Service config passed to gRPC through the ``grpc.service_config`` option.
    DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = {
        "methodConfig": [
            {
                "name": [{}],
                "retryPolicy": {
                    "maxAttempts": 3,
                    "initialBackoff": "0.01s",
                    "maxBackoff": "5s",
                    "backoffMultiplier": 2,
                    "retryableStatusCodes": ["UNAVAILABLE"],
                },
            }
        ],
    }

    def __init__(self, channel: grpc.aio.Channel) -> None:  # type: ignore
        """Wrap *channel*; no stream is opened until a connection task is started."""
        self._channel = channel
        self._send_queue = asyncio.Queue[Message]()
        self._recv_queue = asyncio.Queue[Message]()
        self._connection_task: Task[None] | None = None

    @classmethod
    async def from_connection_string(
        cls, connection_string: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG
    ) -> Self:
        """Open an insecure channel to *connection_string* and start pumping messages.

        Args:
            connection_string: Target address handed to ``grpc.aio.insecure_channel``.
            grpc_config: Service config, JSON-encoded into the
                ``grpc.service_config`` channel option.

        Returns:
            A connection whose background stream task is already running.
        """
        channel = grpc.aio.insecure_channel(
            connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
        )
        await channel.channel_ready()
        instance = cls(channel)
        instance._connection_task = asyncio.create_task(
            instance._connect(channel, instance._send_queue, instance._recv_queue)
        )
        return instance

    async def close(self) -> None:
        """Close the channel, then wait for the background stream task (if any)."""
        await self._channel.close()
        if self._connection_task is not None:
            await self._connection_task

    @staticmethod
    async def _connect(  # type: ignore
        channel: grpc.aio.Channel, send_queue: asyncio.Queue[Message], receive_queue: asyncio.Queue[Message]
    ) -> None:
        """Pump the ``OpenChannel`` stream until EOF.

        Outgoing messages are pulled from *send_queue*; every message read
        from the stream is put on *receive_queue*. If reading or queueing
        raises, the error is logged and the stream is reopened.
        """
        stub: AgentRpcAsyncStub = AgentRpcStub(channel)  # type: ignore

        # TODO: where do exceptions from reading the iterable go? How do we recover from those?
        recv_stream: StreamStreamCall[Message, Message] = stub.OpenChannel(QueueAsyncIterable(send_queue))  # type: ignore

        while True:
            try:
                message = await recv_stream.read()  # type: ignore
                if message == grpc.aio.EOF:  # type: ignore
                    logger.info("EOF")
                    break
                message = cast(Message, message)
                logger.info("Received message: %s", message)
                await receive_queue.put(message)
            except Exception:
                # Report through the module logger (with traceback) instead of
                # printing to stdout, then retry on a fresh stream.
                logger.exception("Error on runtime stream; reopening channel stream")
                recv_stream = stub.OpenChannel(QueueAsyncIterable(send_queue))  # type: ignore

    async def send(self, message: Message) -> None:
        """Queue *message* for delivery over the stream."""
        await self._send_queue.put(message)

    async def recv(self) -> Message:
        """Wait for and return the next message received from the stream."""
        return await self._recv_queue.get()
|
||||
|
||||
|
||||
class WorkerAgentRuntime(AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set)
|
||||
self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {}
|
||||
# If empty, then all namespaces are valid for that agent type
|
||||
self._valid_namespaces: Dict[str, Sequence[str]] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._known_namespaces: set[str] = set()
|
||||
self._read_task: None | Task[None] = None
|
||||
self._running = False
|
||||
self._serialization = Serialization()
|
||||
self._pending_requests: Dict[str, Future[Any]] = {}
|
||||
self._pending_requests_lock = threading.Lock()
|
||||
self._next_request_id = 0
|
||||
self._runtime_connection: RuntimeConnection | None = None
|
||||
|
||||
async def setup_channel(self, connection_string: str) -> None:
|
||||
self._runtime_connection = await RuntimeConnection.from_connection_string(connection_string)
|
||||
if self._read_task is None:
|
||||
self._read_task = asyncio.create_task(self.run_read_loop())
|
||||
self._running = True
|
||||
|
||||
async def send_register_agent_type(self, agent_type: str) -> None:
|
||||
assert self._runtime_connection is not None
|
||||
message = Message(registerAgentType=RegisterAgentType(type=agent_type))
|
||||
await self._runtime_connection.send(message)
|
||||
logger.info("Sent registerAgentType message for %s", agent_type)
|
||||
|
||||
async def run_read_loop(self) -> None:
|
||||
# TODO: catch exceptions and reconnect
|
||||
while self._running:
|
||||
message = await self._runtime_connection.recv() # type: ignore
|
||||
logger.info("Got message: %s", message)
|
||||
oneofcase = Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType":
|
||||
logger.warn("Cant handle registerAgentType")
|
||||
case "request":
|
||||
# request: RpcRequest = message.request
|
||||
# source = AgentId(request.source.name, request.source.namespace)
|
||||
# target = AgentId(request.target.name, request.target.namespace)
|
||||
|
||||
raise NotImplementedError("Sending messages is not yet implemented.")
|
||||
case "response":
|
||||
response: RpcResponse = message.response
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
break
|
||||
future.set_result(response.result)
|
||||
case "event":
|
||||
event: Event = message.event
|
||||
message = self._serialization.deserialize(event.data, type_name=event.type)
|
||||
namespace = event.namespace
|
||||
|
||||
for agent_id in self._per_type_subscribers[(namespace, type(message))]:
|
||||
agent = self._get_agent(agent_id)
|
||||
try:
|
||||
await agent.on_message(message, CancellationToken())
|
||||
except Exception as e:
|
||||
event_logger.error("Error handling message", exc_info=e)
|
||||
|
||||
logger.warn("Cant handle event")
|
||||
case None:
|
||||
logger.warn("No message")
|
||||
|
||||
async def close_channel(self) -> None:
|
||||
self._running = False
|
||||
if self._runtime_connection is not None:
|
||||
await self._runtime_connection.close()
|
||||
if self._read_task is not None:
|
||||
await self._read_task
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
return set(self._agent_factories.keys())
|
||||
|
||||
# Returns the response of the message
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Any:
|
||||
assert self._runtime_connection is not None
|
||||
# create a new future for the result
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
with self._pending_requests_lock:
|
||||
self._next_request_id += 1
|
||||
request_id = self._next_request_id
|
||||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
sender = cast(AgentId, sender)
|
||||
runtime_message = Message(
|
||||
request=RpcRequest(
|
||||
request_id=request_id_str,
|
||||
target=AgentIdProto(name=recipient.name, namespace=recipient.namespace),
|
||||
source=AgentIdProto(name=sender.name, namespace=sender.namespace),
|
||||
data=message,
|
||||
)
|
||||
)
|
||||
# TODO: Find a way to handle timeouts/errors
|
||||
asyncio.create_task(self._runtime_connection.send(runtime_message))
|
||||
return await future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> None:
|
||||
assert self._runtime_connection is not None
|
||||
sender_namespace = sender.namespace if sender is not None else None
|
||||
explicit_namespace = namespace
|
||||
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace:
|
||||
raise ValueError(
|
||||
f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}"
|
||||
)
|
||||
|
||||
assert explicit_namespace is not None or sender_namespace is not None
|
||||
actual_namespace = cast(str, explicit_namespace or sender_namespace)
|
||||
self._process_seen_namespace(actual_namespace)
|
||||
message_type = self._serialization.type_name(message)
|
||||
serialized_message = self._serialization.serialize(message, type_name=message_type)
|
||||
message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message))
|
||||
|
||||
async def write_message() -> None:
|
||||
assert self._runtime_connection is not None
|
||||
await self._runtime_connection.send(message)
|
||||
|
||||
await asyncio.create_task(write_message())
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
|
||||
raise NotImplementedError("Saving state is not yet implemented.")
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
raise NotImplementedError("Loading state is not yet implemented.")
|
||||
|
||||
def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
raise NotImplementedError("Agent metadata is not yet implemented.")
|
||||
|
||||
def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
raise NotImplementedError("Agent save_state is not yet implemented.")
|
||||
|
||||
def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
raise NotImplementedError("Agent load_state is not yet implemented.")
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
) -> None:
|
||||
if name in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {name} already exists.")
|
||||
self._agent_factories[name] = agent_factory
|
||||
|
||||
# For all already prepared namespaces we need to prepare this agent
|
||||
for namespace in self._known_namespaces:
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
# TODO do we need to convert register to async?
|
||||
asyncio.create_task(self.send_register_agent_type(name))
|
||||
|
||||
def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T],
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
return agent
|
||||
|
||||
def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
self._process_seen_namespace(agent_id.namespace)
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if agent_id.name not in self._agent_factories:
|
||||
raise ValueError(f"Agent with name {agent_id.name} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.name]
|
||||
|
||||
agent = self._invoke_agent_factory(agent_factory, agent_id)
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
|
||||
self._serialization.add_type(message_type)
|
||||
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
def get(self, name: str, *, namespace: str = "default") -> AgentId:
|
||||
return self._get_agent(AgentId(name=name, namespace=namespace)).id
|
||||
|
||||
def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy:
|
||||
id = self.get(name, namespace=namespace)
|
||||
return AgentProxy(id, self)
|
||||
|
||||
# Hydrate the agent instances in a namespace. The primary reason for this is
|
||||
# to ensure message type subscriptions are set up.
|
||||
def _process_seen_namespace(self, namespace: str) -> None:
|
||||
if namespace in self._known_namespaces:
|
||||
return
|
||||
|
||||
self._known_namespaces.add(namespace)
|
||||
for name in self._known_agent_names:
|
||||
self._get_agent(AgentId(name=name, namespace=namespace))
|
||||
|
||||
def add_serialization_type(self, message_type: type) -> None:
|
||||
self._serialization.add_type(message_type)
|
||||
Reference in New Issue
Block a user