Initial impl of new register and subscriptions (#542)

* Initial impl of new register and subscriptions

* progress

* test fixes, main issue was unbound self in routed agent
This commit is contained in:
Jack Gerrits
2024-09-18 14:41:04 -04:00
committed by GitHub
parent a479a5e692
commit 093e261158
18 changed files with 423 additions and 156 deletions

View File

@@ -121,16 +121,16 @@
"import asyncio\n",
"\n",
"from autogen_core.application import WorkerAgentRuntime\n",
"from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, try_get_known_serializers_for_type\n",
"from autogen_core.base import AgentId, try_get_known_serializers_for_type\n",
"from autogen_core.components import DefaultSubscription\n",
"\n",
"MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MyMessage))\n",
"\n",
"worker1 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker1.add_message_serializer(try_get_known_serializers_for_type(MyMessage))\n",
"worker1.start()\n",
"await worker1.register(\"worker1\", lambda: MyAgent(\"worker1\"), lambda: [DefaultSubscription()])\n",
"\n",
"worker2 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker2.add_message_serializer(try_get_known_serializers_for_type(MyMessage))\n",
"worker2.start()\n",
"await worker2.register(\"worker2\", lambda: MyAgent(\"worker2\"), lambda: [DefaultSubscription()])\n",
"\n",

View File

@@ -4,8 +4,8 @@ from dataclasses import dataclass
from typing import Any, NoReturn
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext, try_get_known_serializers_for_type
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
@dataclass
@@ -33,6 +33,7 @@ class ReturnedFeedback:
content: str
@default_subscription
class ReceiveAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("Receive Agent")
@@ -49,6 +50,7 @@ class ReceiveAgent(RoutedAgent):
print(f"Unhandled message: {message}")
@default_subscription
class GreeterAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("Greeter Agent")
@@ -67,15 +69,10 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedGreeting))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(ReturnedFeedback))
runtime.start()
await runtime.register("receiver", ReceiveAgent, lambda: [DefaultSubscription()])
await runtime.register("greeter", GreeterAgent, lambda: [DefaultSubscription()])
await ReceiveAgent.register(runtime, "receiver", ReceiveAgent)
await GreeterAgent.register(runtime, "greeter", GreeterAgent)
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())

View File

@@ -5,7 +5,6 @@ from typing import Any, NoReturn
from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import (
MESSAGE_TYPE_REGISTRY,
AgentId,
AgentInstantiationContext,
MessageContext,
@@ -61,9 +60,6 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Greeting))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(AskToGreet))
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(Feedback))
runtime.start()
await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])

View File

@@ -12,6 +12,9 @@ from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from opentelemetry.trace import TracerProvider
from typing_extensions import deprecated
from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
from ..base import (
Agent,
@@ -163,6 +166,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
self._serialization_registry = SerializationRegistry()
@property
def unprocessed_messages(
@@ -522,6 +526,9 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
(await self._get_agent(agent)).load_state(state)
@deprecated(
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
)
async def register(
self,
type: str,
@@ -550,6 +557,29 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._agent_factories[type] = agent_factory
return AgentType(type)
async def register_factory(
self,
*,
type: AgentType,
agent_factory: Callable[[], T | Awaitable[T]],
expected_class: type[T],
) -> AgentType:
async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
agent_instance = await maybe_agent_instance
else:
agent_instance = maybe_agent_instance
if type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance
self._agent_factories[type.type] = factory_wrapper
return type
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
@@ -616,3 +646,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
lazy=lazy,
instance_getter=self._get_agent,
)
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer)

View File

@@ -30,12 +30,12 @@ from typing import (
import grpc
from grpc.aio import StreamStreamCall
from opentelemetry.trace import NoOpTracerProvider, TracerProvider
from typing_extensions import Self
from typing_extensions import Self, deprecated
from autogen_core.base import JSON_DATA_CONTENT_TYPE
from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
from ..base import (
MESSAGE_TYPE_REGISTRY,
Agent,
AgentId,
AgentInstantiationContext,
@@ -63,6 +63,8 @@ event_logger = logging.getLogger("autogen_core.events")
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
type_func_alias = type
class QueueAsyncIterable(AsyncIterator[Any], AsyncIterable[Any]):
def __init__(self, queue: asyncio.Queue[Any]) -> None:
@@ -166,6 +168,7 @@ class WorkerAgentRuntime(AgentRuntime):
self._host_connection: HostConnection | None = None
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._serialization_registry = SerializationRegistry()
def start(self) -> None:
"""Start the runtime in a background task."""
@@ -286,7 +289,7 @@ class WorkerAgentRuntime(AgentRuntime):
raise ValueError("Runtime must be running when sending message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
data_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", recipient, parent=None, extraAttributes={"message_type": data_type, "message_size": len(message)}
):
@@ -297,7 +300,7 @@ class WorkerAgentRuntime(AgentRuntime):
request_id = self._next_request_id
request_id_str = str(request_id)
self._pending_requests[request_id_str] = future
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
serialized_message = self._serialization_registry.serialize(
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
telemetry_metadata = get_telemetry_grpc_metadata()
@@ -334,11 +337,11 @@ class WorkerAgentRuntime(AgentRuntime):
raise ValueError("Runtime must be running when publishing message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
message_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
):
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
serialized_message = self._serialization_registry.serialize(
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
telemetry_metadata = get_telemetry_grpc_metadata()
@@ -387,7 +390,7 @@ class WorkerAgentRuntime(AgentRuntime):
logging.info(f"Processing request from unknown source to {recipient}")
# Deserialize the message.
message = MESSAGE_TYPE_REGISTRY.deserialize(
message = self._serialization_registry.deserialize(
request.payload.data,
type_name=request.payload.data_type,
data_content_type=request.payload.data_content_type,
@@ -426,8 +429,8 @@ class WorkerAgentRuntime(AgentRuntime):
return
# Serialize the result.
result_type = MESSAGE_TYPE_REGISTRY.type_name(result)
serialized_result = MESSAGE_TYPE_REGISTRY.serialize(
result_type = self._serialization_registry.type_name(result)
serialized_result = self._serialization_registry.serialize(
result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
@@ -456,7 +459,7 @@ class WorkerAgentRuntime(AgentRuntime):
extraAttributes={"message_type": response.payload.data_type},
):
# Deserialize the result.
result = MESSAGE_TYPE_REGISTRY.deserialize(
result = self._serialization_registry.deserialize(
response.payload.data,
type_name=response.payload.data_type,
data_content_type=response.payload.data_content_type,
@@ -469,7 +472,7 @@ class WorkerAgentRuntime(AgentRuntime):
future.set_result(result)
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
message = MESSAGE_TYPE_REGISTRY.deserialize(
message = self._serialization_registry.deserialize(
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
)
sender: AgentId | None = None
@@ -509,6 +512,9 @@ class WorkerAgentRuntime(AgentRuntime):
except BaseException as e:
logger.error("Error handling event", exc_info=e)
@deprecated(
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
)
async def register(
self,
type: str,
@@ -542,6 +548,29 @@ class WorkerAgentRuntime(AgentRuntime):
return AgentType(type)
async def register_factory(
self,
*,
type: AgentType,
agent_factory: Callable[[], T | Awaitable[T]],
expected_class: type[T],
) -> AgentType:
async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
agent_instance = await maybe_agent_instance
else:
agent_instance = maybe_agent_instance
if type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance
self._agent_factories[type.type] = factory_wrapper
return type
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
@@ -622,3 +651,6 @@ class WorkerAgentRuntime(AgentRuntime):
lazy=lazy,
instance_getter=self._get_agent,
)
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer)

View File

@@ -10,15 +10,14 @@ from ._agent_props import AgentChildren
from ._agent_proxy import AgentProxy
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._base_agent import BaseAgent, subscription_factory
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._serialization import (
JSON_DATA_CONTENT_TYPE,
MESSAGE_TYPE_REGISTRY,
MessageSerializer,
Serialization,
SerializationRegistry,
UnknownPayload,
try_get_known_serializers_for_type,
)
@@ -36,11 +35,10 @@ __all__ = [
"CancellationToken",
"AgentChildren",
"AgentInstantiationContext",
"MESSAGE_TYPE_REGISTRY",
"TopicId",
"Subscription",
"MessageContext",
"Serialization",
"SerializationRegistry",
"AgentType",
"SubscriptionInstantiationContext",
"MessageHandlerContext",
@@ -48,4 +46,5 @@ __all__ = [
"MessageSerializer",
"try_get_known_serializers_for_type",
"UnknownPayload",
"subscription_factory",
]

View File

@@ -1,12 +1,16 @@
from __future__ import annotations
from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
from collections.abc import Sequence
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Type, TypeVar, overload, runtime_checkable
from typing_extensions import deprecated
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._serialization import MessageSerializer
from ._subscription import Subscription
from ._topic import TopicId
@@ -67,6 +71,9 @@ class AgentRuntime(Protocol):
"""
...
@deprecated(
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
)
async def register(
self,
type: str,
@@ -82,6 +89,34 @@ class AgentRuntime(Protocol):
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.base.AgentInstantiationContext` to access variables like the current runtime and agent ID.
subscriptions (Callable[[], list[Subscription]] | list[Subscription] | None, optional): The subscriptions that the agent should be subscribed to. Defaults to None.
Example:
.. code-block:: python
runtime.register(
"chat_agent",
lambda: ChatCompletionAgent(
description="A generic chat agent.",
system_messages=[SystemMessage("You are a helpful assistant")],
model_client=OpenAIChatCompletionClient(model="gpt-4o"),
memory=BufferedChatMemory(buffer_size=10),
),
)
"""
...
async def register_factory(
self,
*,
type: AgentType,
agent_factory: Callable[[], T | Awaitable[T]],
expected_class: type[T],
) -> AgentType:
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
Args:
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.base.AgentInstantiationContext` to access variables like the current runtime and agent ID.
Example:
.. code-block:: python
@@ -97,7 +132,6 @@ class AgentRuntime(Protocol):
)
"""
...
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
@@ -199,3 +233,13 @@ class AgentRuntime(Protocol):
LookupError: If the subscription does not exist
"""
...
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
"""Add a new message serialization serializer to the runtime
Note: This will deduplicate serializers based on the type_name and data_content_type properties
Args:
serializer (MessageSerializer[Any] | Sequence[MessageSerializer[Any]]): The serializer/s to add
"""
...

View File

@@ -1,18 +1,71 @@
from __future__ import annotations
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import Any, Mapping
from collections.abc import Sequence
from re import S
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, overload
from typing_extensions import Self
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
from ._subscription import UnboundSubscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
T = TypeVar("T", bound=Agent)
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
# Decorator for adding an unbound subscription to an agent
def subscription_factory(subscription: UnboundSubscription) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
cls._unbound_subscriptions_list.append(subscription)
return cls
return decorator
def handles(
type: Type[Any], serializer: MessageSerializer[Any] | List[MessageSerializer[Any]] | None = None
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]:
if serializer is None:
serializer_list = try_get_known_serializers_for_type(type)
else:
serializer_list = [serializer] if not isinstance(serializer, Sequence) else serializer
if len(serializer_list) == 0:
raise ValueError(f"No serializers found for type {type}. Please provide an explicit serializer.")
cls._extra_handles_types.append((type, serializer_list))
return cls
return decorator
class BaseAgent(ABC, Agent):
_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = []
_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = []
@classmethod
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
return cls._extra_handles_types
@classmethod
def _unbound_subscriptions(cls) -> List[UnboundSubscription]:
return cls._unbound_subscriptions_list
@property
def metadata(self) -> AgentMetadata:
assert self._id is not None
@@ -82,3 +135,29 @@ class BaseAgent(ABC, Agent):
def load_state(self, state: Mapping[str, Any]) -> None:
warnings.warn("load_state not implemented", stacklevel=2)
pass
@classmethod
async def register(
cls, runtime: AgentRuntime, type: str, factory: Callable[[], Self | Awaitable[Self]]
) -> AgentType:
agent_type = AgentType(type)
with SubscriptionInstantiationContext.populate_context(agent_type):
subscriptions = []
for unbound_subscription in cls._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
subscriptions.extend(subscriptions_list)
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
for subscription in subscriptions:
await runtime.add_subscription(subscription)
# TODO: deduplication
for _message_type, serializer in cls._handles_types():
runtime.add_message_serializer(serializer)
return agent_type

View File

@@ -1,8 +1,9 @@
import json
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, get_args, get_origin, runtime_checkable
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable
from pydantic import BaseModel
from typing_extensions import deprecated
from autogen_core.base._type_helpers import is_union
@@ -171,13 +172,13 @@ def try_get_known_serializers_for_type(cls: type[Any]) -> list[MessageSerializer
return serializers
class Serialization:
class SerializationRegistry:
def __init__(self) -> None:
# type_name, data_content_type -> serializer
self._serializers: dict[tuple[str, str], MessageSerializer[Any]] = {}
def add_serializer(self, serializer: MessageSerializer[Any] | List[MessageSerializer[Any]]) -> None:
if isinstance(serializer, list):
def add_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
if isinstance(serializer, Sequence):
for c in serializer:
self.add_serializer(c)
return
@@ -203,6 +204,3 @@ class Serialization:
def type_name(self, message: Any) -> str:
return _type_name(message)
MESSAGE_TYPE_REGISTRY = Serialization()

View File

@@ -1,7 +1,9 @@
from typing import Protocol, runtime_checkable
from __future__ import annotations
from autogen_core.base import AgentId
from typing import Any, Awaitable, Callable, Protocol, runtime_checkable
from ._agent_id import AgentId
from ._agent_type import AgentType
from ._topic import TopicId
@@ -58,3 +60,7 @@ class Subscription(Protocol):
CantHandleException: If the subscription cannot handle the topic_id.
"""
...
# Helper alias to represent the lambdas used to define subscriptions
UnboundSubscription = Callable[[], list[Subscription] | Awaitable[list[Subscription]]]

View File

@@ -3,7 +3,7 @@ The :mod:`autogen_core.components` module provides building blocks for creating
"""
from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription
from ._default_subscription import DefaultSubscription, default_subscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._routed_agent import RoutedAgent, TypeRoutedAgent, event, message_handler, rpc

View File

@@ -1,21 +1,25 @@
import inspect
from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_type_hints
from typing import Any, Awaitable, Callable, List, Mapping, Sequence, TypeVar, cast, get_type_hints
from autogen_core.base import MessageContext
from ..base._agent import Agent
from ..base._agent_id import AgentId
from ..base._agent_instantiation import AgentInstantiationContext
from ..base._agent_metadata import AgentMetadata
from ..base._agent_runtime import AgentRuntime
from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_serializers_for_type
from ..base import (
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentRuntime,
AgentType,
MessageContext,
Subscription,
SubscriptionInstantiationContext,
try_get_known_serializers_for_type,
)
from ..base._type_helpers import get_types
from ..base.exceptions import CantHandleException
T = TypeVar("T")
def get_subscriptions_from_closure(
def get_handled_types_from_closure(
closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
) -> Sequence[type]:
args = inspect.getfullargspec(closure)[0]
@@ -58,12 +62,8 @@ class ClosureAgent(Agent):
self._runtime: AgentRuntime = runtime
self._id: AgentId = id
self._description = description
subscription_types = get_subscriptions_from_closure(closure)
# TODO fold this into runtime
for message_type in subscription_types:
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(message_type))
self._subscriptions = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscription_types]
handled_types = get_handled_types_from_closure(closure)
self._expected_types = handled_types
self._closure = closure
@property
@@ -84,9 +84,9 @@ class ClosureAgent(Agent):
return self._runtime
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
if MESSAGE_TYPE_REGISTRY.type_name(message) not in self._subscriptions:
if type(message) not in self._expected_types:
raise CantHandleException(
f"Message type {type(message)} not in target types {self._subscriptions} of {self.id}"
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
)
return await self._closure(self._runtime, self._id, message, ctx)
@@ -95,3 +95,39 @@ class ClosureAgent(Agent):
def load_state(self, state: Mapping[str, Any]) -> None:
raise ValueError("load_state not implemented for ClosureAgent")
@classmethod
async def register(
cls,
runtime: AgentRuntime,
type: str,
closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]],
*,
description: str = "",
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
) -> AgentType:
agent_type = AgentType(type)
subscriptions_list: List[Subscription] = []
if subscriptions is not None:
with SubscriptionInstantiationContext.populate_context(agent_type):
subscriptions_list_result = subscriptions()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list.extend(cast(List[Subscription], await subscriptions_list_result))
else:
subscriptions_list.extend(cast(List[Subscription], subscriptions_list_result))
agent_type = await runtime.register_factory(
type=agent_type,
agent_factory=lambda: ClosureAgent(description=description, closure=closure),
expected_class=cls,
)
for subscription in subscriptions_list:
await runtime.add_subscription(subscription)
handled_types = get_handled_types_from_closure(closure)
for message_type in handled_types:
# TODO: support custom serializers
serializer = try_get_known_serializers_for_type(message_type)
runtime.add_message_serializer(serializer)
return agent_type

View File

@@ -1,6 +1,7 @@
from autogen_core.base.exceptions import CantHandleException
from typing import Callable, Type, TypeVar, overload
from ..base import SubscriptionInstantiationContext
from ..base import BaseAgent, SubscriptionInstantiationContext, subscription_factory
from ..base.exceptions import CantHandleException
from ._type_subscription import TypeSubscription
@@ -30,3 +31,27 @@ class DefaultSubscription(TypeSubscription):
) from e
super().__init__(topic_type, agent_type)
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")
@overload
def default_subscription() -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]: ...
@overload
def default_subscription(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: ...
def default_subscription(
cls: Type[BaseAgentType] | None = None,
) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]] | Type[BaseAgentType]:
if cls is None:
return subscription_factory(lambda: [DefaultSubscription()])
else:
return subscription_factory(lambda: [DefaultSubscription()])(cls)
def type_subscription(topic_type: str) -> Callable[[Type[BaseAgentType]], Type[BaseAgentType]]:
return subscription_factory(lambda: [DefaultSubscription(topic_type=topic_type)])

View File

@@ -5,11 +5,12 @@ from typing import (
Any,
Callable,
Coroutine,
Dict,
DefaultDict,
List,
Literal,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
cast,
@@ -18,14 +19,15 @@ from typing import (
runtime_checkable,
)
from autogen_core.base import try_get_known_serializers_for_type
from typing_extensions import Self
from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext
from ..base import BaseAgent, MessageContext, MessageSerializer, try_get_known_serializers_for_type
from ..base._type_helpers import AnyType, get_types
from ..base.exceptions import CantHandleException
logger = logging.getLogger("autogen_core")
AgentT = TypeVar("AgentT")
ReceivesT = TypeVar("ReceivesT")
ProducesT = TypeVar("ProducesT", covariant=True)
@@ -36,23 +38,25 @@ ProducesT = TypeVar("ProducesT", covariant=True)
# Pyright and mypy disagree on the variance of ReceivesT. Mypy thinks it should be contravariant here.
# Revisit this later to see if we can remove the ignore.
@runtime_checkable
class MessageHandler(Protocol[ReceivesT, ProducesT]): # type: ignore
class MessageHandler(Protocol[AgentT, ReceivesT, ProducesT]): # type: ignore
target_types: Sequence[type]
produces_types: Sequence[type]
is_message_handler: Literal[True]
router: Callable[[ReceivesT, MessageContext], bool]
async def __call__(self, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
# agent_instance binds to self in the method
@staticmethod
async def __call__(agent_instance: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: ...
# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocl for the outer function to check checked arg names
# TODO: Use a protocol for the outer function to check checked arg names
@overload
def message_handler(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]: ...
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
@overload
@@ -62,8 +66,8 @@ def message_handler(
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
@@ -74,22 +78,22 @@ def message_handler(
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
def message_handler(
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]
| MessageHandler[ReceivesT, ProducesT]
| MessageHandler[AgentT, ReceivesT, ProducesT]
):
"""Decorator for generic message handlers.
@@ -113,8 +117,8 @@ def message_handler(
"""
def decorator(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
@@ -136,7 +140,7 @@ def message_handler(
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT:
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
@@ -153,7 +157,7 @@ def message_handler(
return return_value
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
@@ -171,8 +175,8 @@ def message_handler(
@overload
def event(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[ReceivesT, None]: ...
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[AgentT, ReceivesT, None]: ...
@overload
@@ -182,8 +186,8 @@ def event(
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[ReceivesT, None],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[AgentT, ReceivesT, None],
]: ...
@@ -194,22 +198,22 @@ def event(
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[ReceivesT, None],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[AgentT, ReceivesT, None],
]: ...
def event(
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[ReceivesT, None],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[AgentT, ReceivesT, None],
]
| MessageHandler[ReceivesT, None]
| MessageHandler[AgentT, ReceivesT, None]
):
"""Decorator for event message handlers.
@@ -233,8 +237,8 @@ def event(
"""
def decorator(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[ReceivesT, None]:
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[AgentT, ReceivesT, None]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
@@ -255,7 +259,7 @@ def event(
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> None:
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
@@ -272,7 +276,7 @@ def event(
return None
wrapper_handler = cast(MessageHandler[ReceivesT, None], wrapper)
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, None], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
@@ -291,8 +295,8 @@ def event(
@overload
def rpc(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]: ...
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]: ...
@overload
@@ -302,8 +306,8 @@ def rpc(
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
@@ -314,22 +318,22 @@ def rpc(
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]: ...
def rpc(
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
func: None | Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
[Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[AgentT, ReceivesT, ProducesT],
]
| MessageHandler[ReceivesT, ProducesT]
| MessageHandler[AgentT, ReceivesT, ProducesT]
):
"""Decorator for RPC message handlers.
@@ -353,8 +357,8 @@ def rpc(
"""
def decorator(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
func: Callable[[AgentT, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[AgentT, ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
@@ -376,7 +380,7 @@ def rpc(
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT:
async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
@@ -393,7 +397,7 @@ def rpc(
return return_value
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
@@ -440,23 +444,15 @@ class RoutedAgent(BaseAgent):
def __init__(self, description: str) -> None:
# Self is already bound to the handlers
self._handlers: Dict[
self._handlers: DefaultDict[
Type[Any],
List[MessageHandler[Any, Any]],
] = {}
List[MessageHandler[RoutedAgent, Any, Any]],
] = DefaultDict(list)
# Iterate over all attributes in alphabetical order and find message handlers.
for attr in dir(self):
if callable(getattr(self, attr, None)):
handler = getattr(self, attr)
if hasattr(handler, "is_message_handler"):
message_handler = cast(MessageHandler[Any, Any], handler)
for target_type in message_handler.target_types:
self._handlers.setdefault(target_type, []).append(message_handler)
for message_type in self._handlers.keys():
for serializer in try_get_known_serializers_for_type(message_type):
MESSAGE_TYPE_REGISTRY.add_serializer(serializer)
handlers = self._discover_handlers()
for message_handler in handlers:
for target_type in message_handler.target_types:
self._handlers[target_type].append(message_handler)
super().__init__(description)
@@ -471,7 +467,7 @@ class RoutedAgent(BaseAgent):
# Call the first handler whose router returns True and then return the result.
for h in handlers:
if h.router(message, ctx):
return await h(message, ctx)
return await h(self, message, ctx)
return await self.on_unhandled_message(message, ctx) # type: ignore
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
@@ -479,6 +475,33 @@ class RoutedAgent(BaseAgent):
The default implementation logs an info message."""
logger.info(f"Unhandled message: {message}")
@classmethod
def _discover_handlers(cls) -> Sequence[MessageHandler[Any, Any, Any]]:
handlers = []
for attr in dir(cls):
if callable(getattr(cls, attr, None)):
# Since we are getting it from the class, self is not bound
handler = getattr(cls, attr)
if hasattr(handler, "is_message_handler"):
handlers.append(cast(MessageHandler[Any, Any, Any], handler))
return handlers
@classmethod
def _handles_types(cls) -> List[Tuple[Type[Any], List[MessageSerializer[Any]]]]:
# TODO handle deduplication
handlers = cls._discover_handlers()
types: List[Tuple[Type[Any], List[MessageSerializer[Any]]]] = []
types.extend(cls._extra_handles_types)
for handler in handlers:
for t in handler.target_types:
# TODO: support different serializers
serializers = try_get_known_serializers_for_type(t)
if len(serializers) == 0:
raise ValueError(f"No serializers found for type {t}.")
types.append((t, try_get_known_serializers_for_type(t)))
return types
# Deprecation warning for TypeRoutedAgent
class TypeRoutedAgent(RoutedAgent):

View File

@@ -1,8 +1,8 @@
import uuid
from typing import TypeVar
from autogen_core.base.exceptions import CantHandleException
from ..base import AgentId, Subscription, TopicId
from ..base import AgentId, BaseAgent, Subscription, TopicId
from ..base.exceptions import CantHandleException
class TypeSubscription(Subscription):
@@ -51,3 +51,6 @@ class TypeSubscription(Subscription):
raise CantHandleException("TopicId does not match the subscription")
return AgentId(type=self._agent_type, key=topic_id.source)
BaseAgentType = TypeVar("BaseAgentType", bound="BaseAgent")

View File

@@ -5,6 +5,8 @@ import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentRuntime, MessageContext, TopicId
from autogen_core.components import ClosureAgent
from autogen_core.components._default_subscription import DefaultSubscription
from autogen_core.components._default_topic import DefaultTopicId
from autogen_core.components._type_subscription import TypeSubscription
@@ -23,14 +25,12 @@ async def test_register_receives_publish() -> None:
key = id.key
await queue.put((key, message.content))
await runtime.register("name", lambda: ClosureAgent("my_agent", log_message))
await runtime.add_subscription(TypeSubscription("default", "name"))
topic_id = TopicId("default", "default")
await ClosureAgent.register(runtime, "name", log_message, subscriptions=lambda: [DefaultSubscription()])
runtime.start()
await runtime.publish_message(Message("first message"), topic_id=topic_id)
await runtime.publish_message(Message("second message"), topic_id=topic_id)
await runtime.publish_message(Message("third message"), topic_id=topic_id)
await runtime.publish_message(Message("first message"), topic_id=DefaultTopicId())
await runtime.publish_message(Message("second message"), topic_id=DefaultTopicId())
await runtime.publish_message(Message("third message"), topic_id=DefaultTopicId())
await runtime.stop_when_idle()

View File

@@ -5,7 +5,7 @@ import pytest
from autogen_core.base import (
JSON_DATA_CONTENT_TYPE,
MessageSerializer,
Serialization,
SerializationRegistry,
try_get_known_serializers_for_type,
)
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
@@ -41,7 +41,7 @@ class NestingPydanticDataclassMessage:
def test_pydantic() -> None:
serde = Serialization()
serde = SerializationRegistry()
serde.add_serializer(try_get_known_serializers_for_type(PydanticMessage))
message = PydanticMessage(message="hello")
@@ -54,7 +54,7 @@ def test_pydantic() -> None:
def test_nested_pydantic() -> None:
serde = Serialization()
serde = SerializationRegistry()
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticMessage))
message = NestingPydanticMessage(message="hello", nested=PydanticMessage(message="world"))
@@ -66,7 +66,7 @@ def test_nested_pydantic() -> None:
def test_dataclass() -> None:
serde = Serialization()
serde = SerializationRegistry()
serde.add_serializer(try_get_known_serializers_for_type(DataclassMessage))
message = DataclassMessage(message="hello")
@@ -78,7 +78,7 @@ def test_dataclass() -> None:
def test_nesting_dataclass_dataclass() -> None:
serde = Serialization()
serde = SerializationRegistry()
with pytest.raises(ValueError):
serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage))
@@ -102,14 +102,13 @@ def test_nesting_union_old_syntax_dataclass(
def test_nesting_dataclass_pydantic() -> None:
serde = Serialization()
serde = SerializationRegistry()
with pytest.raises(ValueError):
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage))
def test_invalid_type() -> None:
serde = Serialization()
serde = SerializationRegistry()
try:
serde.add_serializer(try_get_known_serializers_for_type(str))
except ValueError as e:
@@ -117,7 +116,7 @@ def test_invalid_type() -> None:
def test_custom_type() -> None:
serde = Serialization()
serde = SerializationRegistry()
class CustomStringTypeSerializer(MessageSerializer[str]):
@property

View File

@@ -3,7 +3,6 @@ import asyncio
import pytest
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
from autogen_core.base import (
MESSAGE_TYPE_REGISTRY,
AgentId,
AgentInstantiationContext,
TopicId,
@@ -20,8 +19,6 @@ async def test_agent_names_must_be_unique() -> None:
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
worker = WorkerAgentRuntime(host_address=host_address)
worker.start()
@@ -53,9 +50,8 @@ async def test_register_receives_publish() -> None:
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
worker = WorkerAgentRuntime(host_address=host_address)
worker.add_message_serializer(try_get_known_serializers_for_type(MessageType))
worker.start()
await worker.register("name", LoopbackAgent)
@@ -87,8 +83,9 @@ async def test_register_receives_publish_cascade() -> None:
host_address = "localhost:50053"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(CascadingMessageType))
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
runtime.start()
num_agents = 5
@@ -125,8 +122,8 @@ async def test_default_subscription() -> None:
host_address = "localhost:50054"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.start()
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
@@ -155,8 +152,8 @@ async def test_non_default_default_subscription() -> None:
host_address = "localhost:50055"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.start()
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
@@ -185,8 +182,8 @@ async def test_non_publish_to_other_source() -> None:
host_address = "localhost:50056"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
MESSAGE_TYPE_REGISTRY.add_serializer(try_get_known_serializers_for_type(MessageType))
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.start()
await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])