Implementation for @rpc and @event decorators (#504)

* Implementation for `@rpc` and `@event` decorators

* update
This commit is contained in:
Eric Zhu
2024-09-16 10:20:44 -07:00
committed by GitHub
parent 561897b4ee
commit aecb437d85
5 changed files with 347 additions and 28 deletions

View File

@@ -6,7 +6,7 @@ from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._routed_agent import RoutedAgent, TypeRoutedAgent, message_handler
from ._routed_agent import RoutedAgent, TypeRoutedAgent, message_handler, event, rpc
from ._type_subscription import TypeSubscription
from ._types import FunctionCall
@@ -16,6 +16,8 @@ __all__ = [
"TypeRoutedAgent",
"ClosureAgent",
"message_handler",
"event",
"rpc",
"FunctionCall",
"TypeSubscription",
"DefaultSubscription",

View File

@@ -91,9 +91,9 @@ def message_handler(
]
| MessageHandler[ReceivesT, ProducesT]
):
"""Decorator for message handlers.
"""Decorator for generic message handlers.
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle messages.
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle both event and RPC messages.
These methods must have a specific signature that needs to be followed for it to be valid:
- The method must be an `async` method.
@@ -169,19 +169,259 @@ def message_handler(
raise ValueError("Invalid arguments")
@overload
def event(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[ReceivesT, None]: ...
@overload
def event(
func: None = None,
*,
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[ReceivesT, None],
]: ...
@overload
def event(
func: None = None,
*,
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[ReceivesT, None],
]: ...
def event(
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]]],
MessageHandler[ReceivesT, None],
]
| MessageHandler[ReceivesT, None]
):
"""Decorator for event message handlers.
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle event messages.
These methods must have a specific signature that needs to be followed for it to be valid:
- The method must be an `async` method.
- The method must be decorated with the `@message_handler` decorator.
- The method must have exactly 3 arguments:
1. `self`
2. `message`: The event message to be handled, this must be type-hinted with the message type that it is intended to handle.
3. `ctx`: A :class:`autogen_core.base.MessageContext` object.
- The method must return `None`.
Handlers can handle more than one message type by accepting a Union of the message types.
Args:
func: The function to be decorated.
strict: If `True`, the handler will raise an exception if the message type is not in the target types. If `False`, it will log a warning instead.
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
"""
def decorator(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, None]],
) -> MessageHandler[ReceivesT, None]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found. Please provide a type hint for the message parameter.")
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found. Please use `None` as the type hint of the return type.")
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> None:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, ctx) # type: ignore
if return_value is not None:
if strict:
raise ValueError(f"Return type {type(return_value)} is not None.")
else:
logger.warning(f"Return type {type(return_value)} is not None. It will be ignored.")
return None
wrapper_handler = cast(MessageHandler[ReceivesT, None], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
# Wrap the match function with a check on the is_rpc flag.
wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True)
return wrapper_handler
if func is None and not callable(func):
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")
@overload
def rpc(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]: ...
@overload
def rpc(
func: None = None,
*,
match: None = ...,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]: ...
@overload
def rpc(
func: None = None,
*,
match: Callable[[ReceivesT, MessageContext], bool],
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]: ...
def rpc(
func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
match: None | Callable[[ReceivesT, MessageContext], bool] = None,
) -> (
Callable[
[Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]
| MessageHandler[ReceivesT, ProducesT]
):
"""Decorator for RPC message handlers.
Add this decorator to methods in a :class:`RoutedAgent` class that are intended to handle RPC messages.
These methods must have a specific signature that needs to be followed for it to be valid:
- The method must be an `async` method.
- The method must be decorated with the `@message_handler` decorator.
- The method must have exactly 3 arguments:
1. `self`
2. `message`: The message to be handled, this must be type-hinted with the message type that it is intended to handle.
3. `ctx`: A :class:`autogen_core.base.MessageContext` object.
- The method must be type hinted with what message types it can return as a response, or it can return `None` if it does not return anything.
Handlers can handle more than one message type by accepting a Union of the message types. It can also return more than one message type by returning a Union of the message types.
Args:
func: The function to be decorated.
strict: If `True`, the handler will raise an exception if the message type or return type is not in the target types. If `False`, it will log a warning instead.
match: A function that takes the message and the context as arguments and returns a boolean. This is used for secondary routing after the message type. For handlers addressing the same message type, the match function is applied in alphabetical order of the handlers and the first matching handler will be called while the rest are skipped. If `None`, the first handler in alphabetical order matching the same message type will be called.
"""
def decorator(
func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")
if "return" not in type_hints:
raise AssertionError("return not found in function signature")
# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found")
# print(type_hints)
return_types = get_types(type_hints["return"])
if return_types is None:
raise AssertionError("Return type not found")
# Convert target_types to list and stash
@wraps(func)
async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT:
if type(message) not in target_types:
if strict:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")
return_value = await func(self, message, ctx)
if AnyType not in return_types and type(return_value) not in return_types:
if strict:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
else:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
return return_value
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True
wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True)
return wrapper_handler
if func is None and not callable(func):
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")
class RoutedAgent(BaseAgent):
"""A base class for agents that route messages to handlers based on the type of the message
and optional matching functions.
To create a routed agent, subclass this class and add message handlers as methods decorated with
the :func:`message_handler` decorator.
either :func:`event` or :func:`rpc` decorator.
Example:
.. code-block:: python
from autogen_core.base import MessageContext
from autogen_core.components import RoutedAgent, message_handler
from autogen_core.components import RoutedAgent, event, rpc
# Assume Message, MessageWithContent, and Response are defined elsewhere.
@@ -189,12 +429,12 @@ class RoutedAgent(BaseAgent):
def __init__(self):
super().__init__("MyAgent")
@message_handler
async def handle_message(self, message: Message, ctx: MessageContext) -> Response:
return Response()
@event
async def handle_event_message(self, message: Message, ctx: MessageContext) -> None:
self.publish_message(MessageWithContent("event handled"), ctx.topic_id)
@message_handler(match=lambda message, ctx: message.content == "special")
async def handle_special_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
@rpc(match=lambda message, ctx: message.content == "special")
async def handle_special_rpc_message(self, message: MessageWithContent, ctx: MessageContext) -> Response:
return Response()
"""
@@ -223,7 +463,7 @@ class RoutedAgent(BaseAgent):
async def on_message(self, message: Any, ctx: MessageContext) -> Any | None:
"""Handle a message by routing it to the appropriate message handler.
Do not override this method in subclasses. Instead, add message handlers as methods decorated with
the :func:`message_handler` decorator."""
either the :func:`event` or :func:`rpc` decorator."""
key_type: Type[Any] = type(message) # type: ignore
handlers = self._handlers.get(key_type) # type: ignore
if handlers is not None:

View File

@@ -5,7 +5,7 @@ from typing import Callable, cast
import pytest
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, MessageContext, TopicId
from autogen_core.components import RoutedAgent, TypeSubscription, message_handler
from autogen_core.components import RoutedAgent, TypeSubscription, message_handler, event, rpc
from test_utils import LoopbackAgent
@@ -115,3 +115,95 @@ async def test_routed_agent_message_matching() -> None:
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
assert agent.handler_one_called is True
assert agent.handler_two_called is True
class EventAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("An event agent.")
self.num_calls = [0, 0]
@event(match=lambda msg, ctx: msg.value == "one") # type: ignore
async def on_event_one(self, message: TestMessage, ctx: MessageContext) -> None:
self.num_calls[0] += 1
@event(match=lambda msg, ctx: msg.value == "two") # type: ignore
async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None:
self.num_calls[1] += 1
@pytest.mark.asyncio
async def test_event() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", EventAgent, lambda: [TypeSubscription("default", "counter")])
agent_id = AgentId(type="counter", key="default")
# Send a broadcast message.
runtime.start()
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1
assert agent.num_calls[1] == 0
# Send another broadcast message.
runtime.start()
await runtime.publish_message(TestMessage("two"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1
assert agent.num_calls[1] == 1
# Send an RPC message, expect no change.
runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1
assert agent.num_calls[1] == 1
class RPCAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("An RPC agent.")
self.num_calls = [0, 0]
@rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore
async def on_rpc_one(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
self.num_calls[0] += 1
return message
@rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore
async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
self.num_calls[1] += 1
return message
@pytest.mark.asyncio
async def test_rpc() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", RPCAgent, lambda: [TypeSubscription("default", "counter")])
agent_id = AgentId(type="counter", key="default")
# Send an RPC message.
runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1
assert agent.num_calls[1] == 0
# Send another RPC message.
runtime.start()
await runtime.send_message(TestMessage("two"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1
assert agent.num_calls[1] == 1
# Send a broadcast message, expect no change.
runtime.start()
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1
assert agent.num_calls[1] == 1