From 49b52db6ea27ffb6aa976a4d924b2da67b93b695 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 30 Dec 2024 15:00:42 -0500 Subject: [PATCH] Make `register_factory` a user facing API (#4854) * Make register_factory a user facing API * fix docstring * Update python/packages/autogen-core/src/autogen_core/_agent_runtime.py Co-authored-by: Eric Zhu * formatting --------- Co-authored-by: Eric Zhu --- .../src/autogen_core/_agent_runtime.py | 51 +++++++++++++++++-- .../_single_threaded_agent_runtime.py | 9 ++-- .../runtimes/grpc/_worker_runtime.py | 9 ++-- 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index d1483cbcf..5a3ebefbc 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -73,16 +73,59 @@ class AgentRuntime(Protocol): async def register_factory( self, - *, - type: AgentType, + type: str | AgentType, agent_factory: Callable[[], T | Awaitable[T]], - expected_class: type[T], + *, + expected_class: type[T] | None = None, ) -> AgentType: - """Register an agent factory with the runtime associated with a specific type. The type must be unique. + """Register an agent factory with the runtime associated with a specific type. The type must be unique. This API does not add any subscriptions. + + .. note:: + + This is a low level API and usually the agent class's `register` method should be used instead, as this also handles subscriptions automatically. + + Example: + + .. code-block:: python + + from dataclasses import dataclass + + from autogen_core import AgentRuntime, MessageContext, RoutedAgent, event + from autogen_core.models import UserMessage + + + @dataclass + class MyMessage: + content: str + + + class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My core agent") + + @event + async def handler(self, message: UserMessage, context: MessageContext) -> None: + print("Event received: ", message.content) + + + async def my_agent_factory(): + return MyAgent() + + + async def main() -> None: + runtime: AgentRuntime = ... # type: ignore + await runtime.register_factory("my_agent", lambda: MyAgent()) + + + import asyncio + + asyncio.run(main()) + Args: type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes. agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.AgentInstantiationContext` to access variables like the current runtime and agent ID. + expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. """ ... diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index a866ab5c2..5141ac9bc 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -558,11 +558,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def register_factory( self, - *, - type: AgentType, + type: str | AgentType, agent_factory: Callable[[], T | Awaitable[T]], - expected_class: type[T], + *, + expected_class: type[T] | None = None, ) -> AgentType: + if isinstance(type, str): + type = AgentType(type) + if type.type in self._agent_factories: raise ValueError(f"Agent with type {type} already exists.") diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 10fbd3e32..d331be766 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -671,11 +671,14 @@ class GrpcWorkerAgentRuntime(AgentRuntime): async def register_factory( self, - *, - type: AgentType, + type: str | AgentType, agent_factory: Callable[[], T | Awaitable[T]], - expected_class: type[T], + *, + expected_class: type[T] | None = None, ) -> AgentType: + if isinstance(type, str): + type = AgentType(type) + if type.type in self._agent_factories: raise ValueError(f"Agent with type {type} already exists.") if self._host_connection is None: