mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Implementation for @rpc and @event decorators (#504)
* Implementation for `@rpc` and `@event` decorators * update
This commit is contained in:
@@ -6,7 +6,7 @@ from ._closure_agent import ClosureAgent
|
||||
from ._default_subscription import DefaultSubscription
|
||||
from ._default_topic import DefaultTopicId
|
||||
from ._image import Image
|
||||
from ._routed_agent import RoutedAgent, TypeRoutedAgent, message_handler
|
||||
from ._routed_agent import RoutedAgent, TypeRoutedAgent, message_handler, event, rpc
|
||||
from ._type_subscription import TypeSubscription
|
||||
from ._types import FunctionCall
|
||||
|
||||
@@ -16,6 +16,8 @@ __all__ = [
|
||||
"TypeRoutedAgent",
|
||||
"ClosureAgent",
|
||||
"message_handler",
|
||||
"event",
|
||||
"rpc",
|
||||
"FunctionCall",
|
||||
"TypeSubscription",
|
||||
"DefaultSubscription",
|
||||
|
||||
@@ -91,9 +91,9 @@ def message_handler(
|
||||
]
|
||||
| MessageHandler[ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for message handlers.
|
||||
"""Decorator for generic message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle messages.
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle both event and RPC messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
@@ -169,19 +169,259 @@ def message_handler(
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[ReceivesT, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def event(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[ReceivesT, None],
|
||||
]: ...
|
||||
|
||||
|
||||
def event(
|
||||
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
|
||||
MessageHandler[ReceivesT, None],
|
||||
]
|
||||
| MessageHandler[ReceivesT, None]
|
||||
):
|
||||
"""Decorator for event message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle event messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The event message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`autogen_core.base.MessageContext` object.
|
||||
- The method must return `None`.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
|
||||
) -> MessageHandler[ReceivesT, None]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found. Please provide a type hint for the message parameter.")
|
||||
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> None:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx) # type: ignore
|
||||
|
||||
if return_value is not None:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} is not None.")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.")
|
||||
|
||||
return None
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, None], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
# Wrap the match function with a check on the is_rpc flag.
|
||||
wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: None = None,
|
||||
*,
|
||||
match: None = ...,
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def rpc(
|
||||
func: None = None,
|
||||
*,
|
||||
match: Callable[[ReceivesT, MessageContext], bool],
|
||||
strict: bool = ...,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
]: ...
|
||||
|
||||
|
||||
def rpc(
|
||||
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
]
|
||||
| MessageHandler[ReceivesT, ProducesT]
|
||||
):
|
||||
"""Decorator for RPC message handlers.
|
||||
|
||||
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle RPC messages.
|
||||
These methods must have a specific signature that needs to be followed for it to be valid:
|
||||
|
||||
- The method must be an `async` method.
|
||||
- The method must be decorated with the `@message_handler` decorator.
|
||||
- The method must have exactly 3 arguments:
|
||||
1. `self`
|
||||
2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle.
|
||||
3. `ctx`: A :class:`autogen_core.base.MessageContext` object.
|
||||
- The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything.
|
||||
|
||||
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated.
|
||||
strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead.
|
||||
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
# print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT:
|
||||
if type(message) not in target_types:
|
||||
if strict:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, ctx)
|
||||
|
||||
if AnyType not in return_types and type(return_value) not in return_types:
|
||||
if strict:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
else:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True)
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
if func is None and not callable(func):
|
||||
return decorator
|
||||
elif callable(func):
|
||||
return decorator(func)
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
class RoutedAgent(BaseAgent):
|
||||
"""A base class for agents that route messages to handlers based on the type of the message
|
||||
and optional matching functions.
|
||||
|
||||
To create a routed agent, subclass this class and add message handlers as methods decorated with
|
||||
the :func:`message_handler` decorator.
|
||||
either :func:`event` or :func:`rpc` decorator.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_core.base import MessageContext
|
||||
from autogen_core.components import RoutedAgent, message_handler
|
||||
from autogen_core.components import RoutedAgent, event, rpc
|
||||
# Assume Message, MessageWithContent, and Response are defined elsewhere.
|
||||
|
||||
|
||||
@@ -189,12 +429,12 @@ class RoutedAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__("MyAgent")
|
||||
|
||||
@message_handler
|
||||
async def handle_message(self, message: Message, ctx: MessageContext) -> Response:
|
||||
return Response()
|
||||
@event
|
||||
async def handle_event_message(self, message: Message, ctx: MessageContext) -> None:
|
||||
self.publish_message(MessageWithContent("event handled"), ctx.topic_id)
|
||||
|
||||
@message_handler(match=lambda message, ctx: message.content == "special")
|
||||
async def handle_special_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
|
||||
@rpc(match=lambda message, ctx: message.content == "special")
|
||||
async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
|
||||
return Response()
|
||||
"""
|
||||
|
||||
@@ -223,7 +463,7 @@ class RoutedAgent(BaseAgent):
|
||||
async def on_message(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||
"""Handle a message by routing it to the appropriate message handler.
|
||||
Do not override this method in subclasses. Instead, add message handlers as methods decorated with
|
||||
the :func:`message_handler` decorator."""
|
||||
either the :func:`event` or :func:`rpc` decorator."""
|
||||
key_type: Type[Any] = type(message) # type: ignore
|
||||
handlers = self._handlers.get(key_type) # type: ignore
|
||||
if handlers is not None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Callable, cast
|
||||
import pytest
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base import AgentId, MessageContext, TopicId
|
||||
from autogen_core.components import RoutedAgent, TypeSubscription, message_handler
|
||||
from autogen_core.components import RoutedAgent, TypeSubscription, message_handler, event, rpc
|
||||
from test_utils import LoopbackAgent
|
||||
|
||||
|
||||
@@ -115,3 +115,95 @@ async def test_routed_agent_message_matching() -> None:
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
|
||||
assert agent.handler_one_called is True
|
||||
assert agent.handler_two_called is True
|
||||
|
||||
|
||||
class EventAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("An event agent.")
|
||||
self.num_calls = [0, 0]
|
||||
|
||||
@event(match=lambda msg, ctx: msg.value == "one") # type: ignore
|
||||
async def on_event_one(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
self.num_calls[0] += 1
|
||||
|
||||
@event(match=lambda msg, ctx: msg.value == "two") # type: ignore
|
||||
async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
self.num_calls[1] += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register("counter", EventAgent, lambda: [TypeSubscription("default", "counter")])
|
||||
agent_id = AgentId(type="counter", key="default")
|
||||
|
||||
# Send a broadcast message.
|
||||
runtime.start()
|
||||
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
assert agent.num_calls[1] == 0
|
||||
|
||||
# Send another broadcast message.
|
||||
runtime.start()
|
||||
await runtime.publish_message(TestMessage("two"), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
assert agent.num_calls[1] == 1
|
||||
|
||||
# Send an RPC message, expect no change.
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("one"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
assert agent.num_calls[1] == 1
|
||||
|
||||
|
||||
class RPCAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("An RPC agent.")
|
||||
self.num_calls = [0, 0]
|
||||
|
||||
@rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore
|
||||
async def on_rpc_one(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
|
||||
self.num_calls[0] += 1
|
||||
return message
|
||||
|
||||
@rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore
|
||||
async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
|
||||
self.num_calls[1] += 1
|
||||
return message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rpc() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
await runtime.register("counter", RPCAgent, lambda: [TypeSubscription("default", "counter")])
|
||||
agent_id = AgentId(type="counter", key="default")
|
||||
|
||||
# Send an RPC message.
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("one"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
assert agent.num_calls[1] == 0
|
||||
|
||||
# Send another RPC message.
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("two"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
assert agent.num_calls[1] == 1
|
||||
|
||||
# Send a broadcast message, expect no change.
|
||||
runtime.start()
|
||||
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
assert agent.num_calls[1] == 1
|
||||
|
||||
Reference in New Issue
Block a user