mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Initial impl of new register and subscriptions (#542)
* Initial impl of new register and subscriptions * progress * test fixes, main issue was unbound self in routed agent
This commit is contained in:
@@ -121,16 +121,16 @@
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"from autogen_core.application import WorkerAgentRuntime\n",
|
||||
"from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, try_get_known_serializers_for_type\n",
|
||||
"from autogen_core.base import AgentId, try_get_known_serializers_for_type\n",
|
||||
"from autogen_core.components import DefaultSubscription\n",
|
||||
"\n",
|
||||
"MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MyMessage))\n",
|
||||
"\n",
|
||||
"worker1 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker1.add_message_serializer(try_get_known_serializers_for_type(MyMessage))\n",
|
||||
"worker1.start()\n",
|
||||
"await worker1.register(\"worker1\", lambda: MyAgent(\"worker1\"), lambda: [DefaultSubscription()])\n",
|
||||
"\n",
|
||||
"worker2 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker2.add_message_serializer(try_get_known_serializers_for_type(MyMessage))\n",
|
||||
"worker2.start()\n",
|
||||
"await worker2.register(\"worker2\", lambda: MyAgent(\"worker2\"), lambda: [DefaultSubscription()])\n",
|
||||
"\n",
|
||||
|
||||
@@ -4,8 +4,8 @@ from dataclasses import dataclass
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext, try_get_known_serializers_for_type
|
||||
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
|
||||
from autogen_core.base import MessageContext
|
||||
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -33,6 +33,7 @@ class ReturnedFeedback:
|
||||
content: str
|
||||
|
||||
|
||||
@default_subscription
|
||||
class ReceiveAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Receive Agent")
|
||||
@@ -49,6 +50,7 @@ class ReceiveAgent(RoutedAgent):
|
||||
print(f"Unhandled message: {message}")
|
||||
|
||||
|
||||
@default_subscription
|
||||
class GreeterAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Greeter Agent")
|
||||
@@ -67,15 +69,10 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedGreeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedFeedback))
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()])
|
||||
await runtime.register("greeter", GreeterAgent, lambda: [DefaultSubscription()])
|
||||
await ReceiveAgent.register(runtime, "receiver", ReceiveAgent)
|
||||
await GreeterAgent.register(runtime, "greeter", GreeterAgent)
|
||||
|
||||
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, NoReturn
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.base import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
MessageContext,
|
||||
@@ -61,9 +60,6 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])
|
||||
|
||||
@@ -12,6 +12,9 @@ from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
|
||||
|
||||
from ..base import (
|
||||
Agent,
|
||||
@@ -163,6 +166,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._run_context: RunContext | None = None
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
@@ -522,6 +526,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
(await self._get_agent(agent)).load_state(state)
|
||||
|
||||
@deprecated(
|
||||
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
|
||||
)
|
||||
async def register(
|
||||
self,
|
||||
type: str,
|
||||
@@ -550,6 +557,29 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._agent_factories[type] = agent_factory
|
||||
return AgentType(type)
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
*,
|
||||
type: AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
async def factory_wrapper() -> T:
|
||||
maybe_agent_instance = agent_factory()
|
||||
if inspect.isawaitable(maybe_agent_instance):
|
||||
agent_instance = await maybe_agent_instance
|
||||
else:
|
||||
agent_instance = maybe_agent_instance
|
||||
|
||||
if type_func_alias(agent_instance) != expected_class:
|
||||
raise ValueError("Factory registered using the wrong type.")
|
||||
|
||||
return agent_instance
|
||||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
return type
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
@@ -616,3 +646,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
lazy=lazy,
|
||||
instance_getter=self._get_agent,
|
||||
)
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
self._serialization_registry.add_serializer(serializer)
|
||||
|
||||
@@ -30,12 +30,12 @@ from typing import (
|
||||
import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from opentelemetry.trace import NoOpTracerProvider, TracerProvider
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from autogen_core.base import JSON_DATA_CONTENT_TYPE
|
||||
from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
|
||||
|
||||
from ..base import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
@@ -63,6 +63,8 @@ event_logger = logging.getLogger("autogen_core.events")
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
type_func_alias = type
|
||||
|
||||
|
||||
class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
|
||||
def __init__(self, queue: asyncio.Queue[Any]) -> None:
|
||||
@@ -166,6 +168,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
self._host_connection: HostConnection | None = None
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the runtime in a background task."""
|
||||
@@ -286,7 +289,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
raise ValueError("Runtime must be running when sending message.")
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
data_type = self._serialization_registry.type_name(message)
|
||||
with self._trace_helper.trace_block(
|
||||
"create", recipient, parent=None, extraAttributes={"message_type": data_type, "message_size": len(message)}
|
||||
):
|
||||
@@ -297,7 +300,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
request_id = self._next_request_id
|
||||
request_id_str = str(request_id)
|
||||
self._pending_requests[request_id_str] = future
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
serialized_message = self._serialization_registry.serialize(
|
||||
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
telemetry_metadata = get_telemetry_grpc_metadata()
|
||||
@@ -334,11 +337,11 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
raise ValueError("Runtime must be running when publishing message.")
|
||||
if self._host_connection is None:
|
||||
raise RuntimeError("Host connection is not set.")
|
||||
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
|
||||
message_type = self._serialization_registry.type_name(message)
|
||||
with self._trace_helper.trace_block(
|
||||
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
|
||||
):
|
||||
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
serialized_message = self._serialization_registry.serialize(
|
||||
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
telemetry_metadata = get_telemetry_grpc_metadata()
|
||||
@@ -387,7 +390,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
logging.info(f"Processing request from unknown source to {recipient}")
|
||||
|
||||
# Deserialize the message.
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
message = self._serialization_registry.deserialize(
|
||||
request.payload.data,
|
||||
type_name=request.payload.data_type,
|
||||
data_content_type=request.payload.data_content_type,
|
||||
@@ -426,8 +429,8 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
return
|
||||
|
||||
# Serialize the result.
|
||||
result_type = MESSAGE_TYPE_REGISTRY.type_name(result)
|
||||
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(
|
||||
result_type = self._serialization_registry.type_name(result)
|
||||
serialized_result = self._serialization_registry.serialize(
|
||||
result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
)
|
||||
|
||||
@@ -456,7 +459,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
extraAttributes={"message_type": response.payload.data_type},
|
||||
):
|
||||
# Deserialize the result.
|
||||
result = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
result = self._serialization_registry.deserialize(
|
||||
response.payload.data,
|
||||
type_name=response.payload.data_type,
|
||||
data_content_type=response.payload.data_content_type,
|
||||
@@ -469,7 +472,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
future.set_result(result)
|
||||
|
||||
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(
|
||||
message = self._serialization_registry.deserialize(
|
||||
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
|
||||
)
|
||||
sender: AgentId | None = None
|
||||
@@ -509,6 +512,9 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
except BaseException as e:
|
||||
logger.error("Error handling event", exc_info=e)
|
||||
|
||||
@deprecated(
|
||||
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
|
||||
)
|
||||
async def register(
|
||||
self,
|
||||
type: str,
|
||||
@@ -542,6 +548,29 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
return AgentType(type)
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
*,
|
||||
type: AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
async def factory_wrapper() -> T:
|
||||
maybe_agent_instance = agent_factory()
|
||||
if inspect.isawaitable(maybe_agent_instance):
|
||||
agent_instance = await maybe_agent_instance
|
||||
else:
|
||||
agent_instance = maybe_agent_instance
|
||||
|
||||
if type_func_alias(agent_instance) != expected_class:
|
||||
raise ValueError("Factory registered using the wrong type.")
|
||||
|
||||
return agent_instance
|
||||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
return type
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
@@ -622,3 +651,6 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
lazy=lazy,
|
||||
instance_getter=self._get_agent,
|
||||
)
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
self._serialization_registry.add_serializer(serializer)
|
||||
|
||||
@@ -10,15 +10,14 @@ from ._agent_props import AgentChildren
|
||||
from ._agent_proxy import AgentProxy
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent
|
||||
from ._base_agent import BaseAgent, subscription_factory
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._serialization import (
|
||||
JSON_DATA_CONTENT_TYPE,
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
MessageSerializer,
|
||||
Serialization,
|
||||
SerializationRegistry,
|
||||
UnknownPayload,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
@@ -36,11 +35,10 @@ __all__ = [
|
||||
"CancellationToken",
|
||||
"AgentChildren",
|
||||
"AgentInstantiationContext",
|
||||
"MESSAGE_TYPE_REGISTRY",
|
||||
"TopicId",
|
||||
"Subscription",
|
||||
"MessageContext",
|
||||
"Serialization",
|
||||
"SerializationRegistry",
|
||||
"AgentType",
|
||||
"SubscriptionInstantiationContext",
|
||||
"MessageHandlerContext",
|
||||
@@ -48,4 +46,5 @@ __all__ = [
|
||||
"MessageSerializer",
|
||||
"try_get_known_serializers_for_type",
|
||||
"UnknownPayload",
|
||||
"subscription_factory",
|
||||
]
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._serialization import MessageSerializer
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
@@ -67,6 +71,9 @@ class AgentRuntime(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
@deprecated(
|
||||
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
|
||||
)
|
||||
async def register(
|
||||
self,
|
||||
type: str,
|
||||
@@ -82,6 +89,34 @@ class AgentRuntime(Protocol):
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.base.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
subscriptions (Callable[[], list[Subscription]] | list[Subscription] | None, optional): The subscriptions that the agent should be subscribed to. Defaults to None.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
runtime.register(
|
||||
"chat_agent",
|
||||
lambda: ChatCompletionAgent(
|
||||
description="A generic chat agent.",
|
||||
system_messages=[SystemMessage("You are a helpful assistant")],
|
||||
model_client=OpenAIChatCompletionClient(model="gpt-4o"),
|
||||
memory=BufferedChatMemory(buffer_size=10),
|
||||
),
|
||||
)
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
*,
|
||||
type: AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
|
||||
|
||||
Args:
|
||||
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.base.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -97,7 +132,6 @@ class AgentRuntime(Protocol):
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
@@ -199,3 +233,13 @@ class AgentRuntime(Protocol):
|
||||
LookupError: If the subscription does not exist
|
||||
"""
|
||||
...
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
"""Add a new message serialization serializer to the runtime
|
||||
|
||||
Note: This will deduplicate serializers based on the type_name and data_content_type properties
|
||||
|
||||
Args:
|
||||
serializer (MessageSerializer[Any] | Sequence[MessageSerializer[Any]]): The serializer/s to add
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -1,18 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping
|
||||
from collections.abc import Sequence
|
||||
from re import S
|
||||
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, overload
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
|
||||
from ._subscription import UnboundSubscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
|
||||
# Decorator for adding an unbound subscription to an agent
|
||||
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
cls._unbound_subscriptions_list.append(subscription)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handles(
|
||||
type: Type[Any], serializer: MessageSerializer[Any] | List[MessageSerializer[Any]] | None = None
|
||||
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
|
||||
if serializer is None:
|
||||
serializer_list = try_get_known_serializers_for_type(type)
|
||||
else:
|
||||
serializer_list = [serializer] if not isinstance(serializer, Sequence) else serializer
|
||||
|
||||
if len(serializer_list) == 0:
|
||||
raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")
|
||||
|
||||
cls._extra_handles_types.append((type, serializer_list))
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BaseAgent(ABC, Agent):
|
||||
_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
|
||||
_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
return cls._extra_handles_types
|
||||
|
||||
@classmethod
|
||||
def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
|
||||
return cls._unbound_subscriptions_list
|
||||
|
||||
@property
|
||||
def metadata(self) -> AgentMetadata:
|
||||
assert self._id is not None
|
||||
@@ -82,3 +135,29 @@ class BaseAgent(ABC, Agent):
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
warnings.warn("load_state not implemented", stacklevel=2)
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def register(
|
||||
cls, runtime: AgentRuntime, type: str, factory: Callable[[], Self | Awaitable[Self]]
|
||||
) -> AgentType:
|
||||
agent_type = AgentType(type)
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions = []
|
||||
for unbound_subscription in cls._unbound_subscriptions():
|
||||
subscriptions_list_result = unbound_subscription()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
|
||||
subscriptions.extend(subscriptions_list)
|
||||
|
||||
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
|
||||
for subscription in subscriptions:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
# TODO: deduplication
|
||||
for _message_type, serializer in cls._handles_types():
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_type
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from autogen_core.base._type_helpers import is_union
|
||||
|
||||
@@ -171,13 +172,13 @@ def try_get_known_serializers_for_type(cls: type[Any]) -> list[MessageSerializer
|
||||
return serializers
|
||||
|
||||
|
||||
class Serialization:
|
||||
class SerializationRegistry:
|
||||
def __init__(self) -> None:
|
||||
# type_name, data_content_type -> serializer
|
||||
self._serializers: dict[tuple[str, str], MessageSerializer[Any]] = {}
|
||||
|
||||
def add_serializer(self, serializer: MessageSerializer[Any] | List[MessageSerializer[Any]]) -> None:
|
||||
if isinstance(serializer, list):
|
||||
def add_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
if isinstance(serializer, Sequence):
|
||||
for c in serializer:
|
||||
self.add_serializer(c)
|
||||
return
|
||||
@@ -203,6 +204,3 @@ class Serialization:
|
||||
|
||||
def type_name(self, message: Any) -> str:
|
||||
return _type_name(message)
|
||||
|
||||
|
||||
MESSAGE_TYPE_REGISTRY = Serialization()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Protocol, runtime_checkable
|
||||
from __future__ import annotations
|
||||
|
||||
from autogen_core.base import AgentId
|
||||
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
|
||||
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
@@ -58,3 +60,7 @@ class Subscription(Protocol):
|
||||
CantHandleException: If the subscription cannot handle the topic_id.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Helper alias to represent the lambdas used to define subscriptions
|
||||
UnboundSubscription = Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
|
||||
@@ -3,7 +3,7 @@ The :mod:`autogen_core.components` module provides building blocks for creating
|
||||
"""
|
||||
|
||||
from ._closure_agent import ClosureAgent
|
||||
from ._default_subscription import DefaultSubscription
|
||||
from ._default_subscription import DefaultSubscription, default_subscription
|
||||
from ._default_topic import DefaultTopicId
|
||||
from ._image import Image
|
||||
from ._routed_agent import RoutedAgent, TypeRoutedAgent, event, message_handler, rpc
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
import inspect
|
||||
from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_type_hints
|
||||
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, cast, get_type_hints
|
||||
|
||||
from autogen_core.base import MessageContext
|
||||
|
||||
from ..base._agent import Agent
|
||||
from ..base._agent_id import AgentId
|
||||
from ..base._agent_instantiation import AgentInstantiationContext
|
||||
from ..base._agent_metadata import AgentMetadata
|
||||
from ..base._agent_runtime import AgentRuntime
|
||||
from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_serializers_for_type
|
||||
from ..base import (
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentMetadata,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
MessageContext,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from ..base._type_helpers import get_types
|
||||
from ..base.exceptions import CantHandleException
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_subscriptions_from_closure(
|
||||
def get_handled_types_from_closure(
|
||||
closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
|
||||
) -> Sequence[type]:
|
||||
args = inspect.getfullargspec(closure)[0]
|
||||
@@ -58,12 +62,8 @@ class ClosureAgent(Agent):
|
||||
self._runtime: AgentRuntime = runtime
|
||||
self._id: AgentId = id
|
||||
self._description = description
|
||||
subscription_types = get_subscriptions_from_closure(closure)
|
||||
# TODO fold this into runtime
|
||||
for message_type in subscription_types:
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(message_type))
|
||||
|
||||
self._subscriptions = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscription_types]
|
||||
handled_types = get_handled_types_from_closure(closure)
|
||||
self._expected_types = handled_types
|
||||
self._closure = closure
|
||||
|
||||
@property
|
||||
@@ -84,9 +84,9 @@ class ClosureAgent(Agent):
|
||||
return self._runtime
|
||||
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
|
||||
if MESSAGE_TYPE_REGISTRY.type_name(message) not in self._subscriptions:
|
||||
if type(message) not in self._expected_types:
|
||||
raise CantHandleException(
|
||||
f"Message type {type(message)} not in target types {self._subscriptions} of {self.id}"
|
||||
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
|
||||
)
|
||||
return await self._closure(self._runtime, self._id, message, ctx)
|
||||
|
||||
@@ -95,3 +95,39 @@ class ClosureAgent(Agent):
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
raise ValueError("load_state not implemented for ClosureAgent")
|
||||
|
||||
@classmethod
|
||||
async def register(
|
||||
cls,
|
||||
runtime: AgentRuntime,
|
||||
type: str,
|
||||
closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
|
||||
*,
|
||||
description: str = "",
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
|
||||
) -> AgentType:
|
||||
agent_type = AgentType(type)
|
||||
subscriptions_list: List[Subscription] = []
|
||||
if subscriptions is not None:
|
||||
with SubscriptionInstantiationContext.populate_context(agent_type):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list.extend(cast(List[Subscription], await subscriptions_list_result))
|
||||
else:
|
||||
subscriptions_list.extend(cast(List[Subscription], subscriptions_list_result))
|
||||
|
||||
agent_type = await runtime.register_factory(
|
||||
type=agent_type,
|
||||
agent_factory=lambda: ClosureAgent(description=description, closure=closure),
|
||||
expected_class=cls,
|
||||
)
|
||||
for subscription in subscriptions_list:
|
||||
await runtime.add_subscription(subscription)
|
||||
|
||||
handled_types = get_handled_types_from_closure(closure)
|
||||
for message_type in handled_types:
|
||||
# TODO: support custom serializers
|
||||
serializer = try_get_known_serializers_for_type(message_type)
|
||||
runtime.add_message_serializer(serializer)
|
||||
|
||||
return agent_type
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from autogen_core.base.exceptions import CantHandleException
|
||||
from typing import Callable, Type, TypeVar, overload
|
||||
|
||||
from ..base import SubscriptionInstantiationContext
|
||||
from ..base import BaseAgent, SubscriptionInstantiationContext, subscription_factory
|
||||
from ..base.exceptions import CantHandleException
|
||||
from ._type_subscription import TypeSubscription
|
||||
|
||||
|
||||
@@ -30,3 +31,27 @@ class DefaultSubscription(TypeSubscription):
|
||||
) from e
|
||||
|
||||
super().__init__(topic_type, agent_type)
|
||||
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
|
||||
@overload
|
||||
def default_subscription() -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def default_subscription(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: ...
|
||||
|
||||
|
||||
def default_subscription(
|
||||
cls: Type[BaseAgentType] | None = None,
|
||||
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]] | Type[BaseAgentType]:
|
||||
if cls is None:
|
||||
return subscription_factory(lambda: [DefaultSubscription()])
|
||||
else:
|
||||
return subscription_factory(lambda: [DefaultSubscription()])(cls)
|
||||
|
||||
|
||||
def type_subscription(topic_type: str) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
|
||||
return subscription_factory(lambda: [DefaultSubscription(topic_type=topic_type)])
|
||||
|
||||
@@ -5,11 +5,12 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
DefaultDict,
|
||||
List,
|
||||
Literal,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
@@ -18,14 +19,15 @@ from typing import (
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from autogen_core.base import try_get_known_serializers_for_type
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext
|
||||
from ..base import BaseAgent, MessageContext, MessageSerializer, try_get_known_serializers_for_type
|
||||
from ..base._type_helpers import AnyType, get_types
|
||||
from ..base.exceptions import CantHandleException
|
||||
|
||||
logger = logging.getLogger("autogen_core")
|
||||
|
||||
AgentT = TypeVar("AgentT")
|
||||
ReceivesT = TypeVar("ReceivesT")
|
||||
ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
|
||||
@@ -36,23 +38,25 @@ ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here.
|
||||
# Revisit this later to see if we can remove the ignore.
|
||||
@runtime_checkable
|
||||
class MessageHandler(Protocol[ReceivesT, ProducesT]): # type: ignore
|
||||
class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]): # type: ignore
|
||||
target_types: Sequence[type]
|
||||
produces_types: Sequence[type]
|
||||
is_message_handler: Literal[True]
|
||||
router: Callable[[ReceivesT, MessageContext], bool]
|
||||
|
||||
async def __call__(self, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
|
||||
# agent_instance binds to self in the method
|
||||
@staticmethod
|
||||
async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
# TODO: Use a protocl for the outer function to check checked arg names
|
||||
# TODO: Use a protocol for the outer function to check checked arg names
|
||||
|
||||
|
||||
@overload
|
||||
def message_handler(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]: ...
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -62,8 +66,8 @@ def message_handler(
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@@ -74,22 +78,22 @@ def message_handler(
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def message_handler(
|
||||
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[ReceivesT, ProducesT]
|
||||
| MessageHandler[AgentT, ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for generic message handlers.
|
||||
|
||||
@@ -113,8 +117,8 @@ def message_handler(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]:
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
@@ -136,7 +140,7 @@ def message_handler(
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
@@ -153,7 +157,7 @@ def message_handler(
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
@@ -171,8 +175,8 @@ def message_handler(
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[ReceivesT, None]: ...
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -182,8 +186,8 @@ def event(
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[ReceivesT, None],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
@@ -194,22 +198,22 @@ def event(
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[ReceivesT, None],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
def event(
|
||||
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[ReceivesT, None],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[AgentT, ReceivesT, None],
|
||||
]
|
||||
| MessageHandler[ReceivesT, None]
|
||||
| MessageHandler[AgentT, ReceivesT, None]
|
||||
):
|
||||
"""Decorator for event message handlers.
|
||||
|
||||
@@ -233,8 +237,8 @@ def event(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[ReceivesT, None]:
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, None]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
@@ -255,7 +259,7 @@ def event(
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> None:
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
@@ -272,7 +276,7 @@ def event(
|
||||
|
||||
return None
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, None], wrapper)
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
@@ -291,8 +295,8 @@ def event(
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]: ...
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -302,8 +306,8 @@ def rpc(
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@@ -314,22 +318,22 @@ def rpc(
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def rpc(
|
||||
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[AgentT, ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[ReceivesT, ProducesT]
|
||||
| MessageHandler[AgentT, ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for RPC message handlers.
|
||||
|
||||
@@ -353,8 +357,8 @@ def rpc(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]:
|
||||
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
@@ -376,7 +380,7 @@ def rpc(
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
@@ -393,7 +397,7 @@ def rpc(
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
@@ -440,23 +444,15 @@ class RoutedAgent(BaseAgent):
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
# Self is already bound to the handlers
|
||||
self._handlers: Dict[
|
||||
self._handlers: DefaultDict[
|
||||
Type[Any],
|
||||
List[MessageHandler[Any, Any]],
|
||||
] = {}
|
||||
List[MessageHandler[RoutedAgent, Any, Any]],
|
||||
] = DefaultDict(list)
|
||||
|
||||
# Iterate over all attributes in alphabetical order and find message handlers.
|
||||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
handler = getattr(self, attr)
|
||||
if hasattr(handler, "is_message_handler"):
|
||||
message_handler = cast(MessageHandler[Any, Any], handler)
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers.setdefault(target_type, []).append(message_handler)
|
||||
|
||||
for message_type in self._handlers.keys():
|
||||
for serializer in try_get_known_serializers_for_type(message_type):
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(serializer)
|
||||
handlers = self._discover_handlers()
|
||||
for message_handler in handlers:
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type].append(message_handler)
|
||||
|
||||
super().__init__(description)
|
||||
|
||||
@@ -471,7 +467,7 @@ class RoutedAgent(BaseAgent):
|
||||
# Call the first handler whose router returns True and then return the result.
|
||||
for h in handlers:
|
||||
if h.router(message, ctx):
|
||||
return await h(message, ctx)
|
||||
return await h(self, message, ctx)
|
||||
return await self.on_unhandled_message(message, ctx) # type: ignore
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
@@ -479,6 +475,33 @@ class RoutedAgent(BaseAgent):
|
||||
The default implementation logs an info message."""
|
||||
logger.info(f"Unhandled message: {message}")
|
||||
|
||||
@classmethod
|
||||
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
|
||||
handlers = []
|
||||
for attr in dir(cls):
|
||||
if callable(getattr(cls, attr, None)):
|
||||
# Since we are getting it from the class, self is not bound
|
||||
handler = getattr(cls, attr)
|
||||
if hasattr(handler, "is_message_handler"):
|
||||
handlers.append(cast(MessageHandler[Any, Any, Any], handler))
|
||||
return handlers
|
||||
|
||||
@classmethod
|
||||
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
|
||||
# TODO handle deduplication
|
||||
handlers = cls._discover_handlers()
|
||||
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
|
||||
types.extend(cls._extra_handles_types)
|
||||
for handler in handlers:
|
||||
for t in handler.target_types:
|
||||
# TODO: support different serializers
|
||||
serializers = try_get_known_serializers_for_type(t)
|
||||
if len(serializers) == 0:
|
||||
raise ValueError(f"No serializers found for type {t}.")
|
||||
|
||||
types.append((t, try_get_known_serializers_for_type(t)))
|
||||
return types
|
||||
|
||||
|
||||
# Deprecation warning for TypeRoutedAgent
|
||||
class TypeRoutedAgent(RoutedAgent):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import uuid
|
||||
from typing import TypeVar
|
||||
|
||||
from autogen_core.base.exceptions import CantHandleException
|
||||
|
||||
from ..base import AgentId, Subscription, TopicId
|
||||
from ..base import AgentId, BaseAgent, Subscription, TopicId
|
||||
from ..base.exceptions import CantHandleException
|
||||
|
||||
|
||||
class TypeSubscription(Subscription):
|
||||
@@ -51,3 +51,6 @@ class TypeSubscription(Subscription):
|
||||
raise CantHandleException("TopicId does not match the subscription")
|
||||
|
||||
return AgentId(type=self._agent_type, key=topic_id.source)
|
||||
|
||||
|
||||
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
|
||||
|
||||
@@ -5,6 +5,8 @@ import pytest
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, AgentRuntime, MessageContext, TopicId
|
||||
from autogen_core.components import ClosureAgent
|
||||
from autogen_core.components._default_subscription import DefaultSubscription
|
||||
from autogen_core.components._default_topic import DefaultTopicId
|
||||
from autogen_core.components._type_subscription import TypeSubscription
|
||||
|
||||
|
||||
@@ -23,14 +25,12 @@ async def test_register_receives_publish() -> None:
|
||||
key = id.key
|
||||
await queue.put((key, message.content))
|
||||
|
||||
await runtime.register("name", lambda: ClosureAgent("my_agent", log_message))
|
||||
await runtime.add_subscription(TypeSubscription("default", "name"))
|
||||
topic_id = TopicId("default", "default")
|
||||
await ClosureAgent.register(runtime, "name", log_message, subscriptions=lambda: [DefaultSubscription()])
|
||||
runtime.start()
|
||||
|
||||
await runtime.publish_message(Message("first message"), topic_id=topic_id)
|
||||
await runtime.publish_message(Message("second message"), topic_id=topic_id)
|
||||
await runtime.publish_message(Message("third message"), topic_id=topic_id)
|
||||
await runtime.publish_message(Message("first message"), topic_id=DefaultTopicId())
|
||||
await runtime.publish_message(Message("second message"), topic_id=DefaultTopicId())
|
||||
await runtime.publish_message(Message("third message"), topic_id=DefaultTopicId())
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from autogen_core.base import (
|
||||
JSON_DATA_CONTENT_TYPE,
|
||||
MessageSerializer,
|
||||
Serialization,
|
||||
SerializationRegistry,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
|
||||
@@ -41,7 +41,7 @@ class NestingPydanticDataclassMessage:
|
||||
|
||||
|
||||
def test_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde = SerializationRegistry()
|
||||
serde.add_serializer(try_get_known_serializers_for_type(PydanticMessage))
|
||||
|
||||
message = PydanticMessage(message="hello")
|
||||
@@ -54,7 +54,7 @@ def test_pydantic() -> None:
|
||||
|
||||
|
||||
def test_nested_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde = SerializationRegistry()
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticMessage))
|
||||
|
||||
message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world"))
|
||||
@@ -66,7 +66,7 @@ def test_nested_pydantic() -> None:
|
||||
|
||||
|
||||
def test_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde = SerializationRegistry()
|
||||
serde.add_serializer(try_get_known_serializers_for_type(DataclassMessage))
|
||||
|
||||
message = DataclassMessage(message="hello")
|
||||
@@ -78,7 +78,7 @@ def test_dataclass() -> None:
|
||||
|
||||
|
||||
def test_nesting_dataclass_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde = SerializationRegistry()
|
||||
with pytest.raises(ValueError):
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage))
|
||||
|
||||
@@ -102,14 +102,13 @@ def test_nesting_union_old_syntax_dataclass(
|
||||
|
||||
|
||||
def test_nesting_dataclass_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
|
||||
serde = SerializationRegistry()
|
||||
with pytest.raises(ValueError):
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage))
|
||||
|
||||
|
||||
def test_invalid_type() -> None:
|
||||
serde = Serialization()
|
||||
serde = SerializationRegistry()
|
||||
try:
|
||||
serde.add_serializer(try_get_known_serializers_for_type(str))
|
||||
except ValueError as e:
|
||||
@@ -117,7 +116,7 @@ def test_invalid_type() -> None:
|
||||
|
||||
|
||||
def test_custom_type() -> None:
|
||||
serde = Serialization()
|
||||
serde = SerializationRegistry()
|
||||
|
||||
class CustomStringTypeSerializer(MessageSerializer[str]):
|
||||
@property
|
||||
|
||||
@@ -3,7 +3,6 @@ import asyncio
|
||||
import pytest
|
||||
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
|
||||
from autogen_core.base import (
|
||||
MESSAGE_TYPE_REGISTRY,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
TopicId,
|
||||
@@ -20,8 +19,6 @@ async def test_agent_names_must_be_unique() -> None:
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
|
||||
@@ -53,9 +50,8 @@ async def test_register_receives_publish() -> None:
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
worker.start()
|
||||
|
||||
await worker.register("name", LoopbackAgent)
|
||||
@@ -87,8 +83,9 @@ async def test_register_receives_publish_cascade() -> None:
|
||||
host_address = "localhost:50053"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
runtime.start()
|
||||
|
||||
num_agents = 5
|
||||
@@ -125,8 +122,8 @@ async def test_default_subscription() -> None:
|
||||
host_address = "localhost:50054"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
@@ -155,8 +152,8 @@ async def test_non_default_default_subscription() -> None:
|
||||
host_address = "localhost:50055"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
|
||||
@@ -185,8 +182,8 @@ async def test_non_publish_to_other_source() -> None:
|
||||
host_address = "localhost:50056"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
runtime.start()
|
||||
|
||||
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
|
||||
|
||||
Reference in New Issue
Block a user