diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/termination-with-intervention.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/termination-with-intervention.ipynb index 4bed96d6b..554dbf0bf 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/termination-with-intervention.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/termination-with-intervention.ipynb @@ -23,7 +23,6 @@ "from typing import Any\n", "\n", "from autogen_core import (\n", - " AgentId,\n", " DefaultInterventionHandler,\n", " DefaultTopicId,\n", " MessageContext,\n", @@ -100,7 +99,7 @@ " def __init__(self) -> None:\n", " self._termination_value: Termination | None = None\n", "\n", - " async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n", + " async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:\n", " if isinstance(message, Termination):\n", " self._termination_value = message\n", " return message\n", @@ -171,7 +170,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/tool-use-with-intervention.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/tool-use-with-intervention.ipynb index 37def894b..2006b2a86 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/tool-use-with-intervention.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/tool-use-with-intervention.ipynb @@ -131,7 +131,9 @@ "outputs": [], "source": [ "class ToolInterventionHandler(DefaultInterventionHandler):\n", - " async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n", + " async def on_send(\n", + " self, message: Any, *, message_context: MessageContext, recipient: AgentId\n", + " ) -> Any | type[DropMessage]:\n", " if isinstance(message, FunctionCall):\n", " # Request user prompt for tool execution.\n", " user_input = input(\n", diff --git a/python/packages/autogen-core/samples/slow_human_in_loop.py b/python/packages/autogen-core/samples/slow_human_in_loop.py index eb4e627bb..8762b588a 100644 --- a/python/packages/autogen-core/samples/slow_human_in_loop.py +++ b/python/packages/autogen-core/samples/slow_human_in_loop.py @@ -31,7 +31,6 @@ from dataclasses import dataclass from typing import Any, Mapping, Optional from autogen_core import ( - AgentId, CancellationToken, DefaultInterventionHandler, DefaultTopicId, @@ -211,7 +210,7 @@ class NeedsUserInputHandler(DefaultInterventionHandler): def __init__(self): self.question_for_user: GetSlowUserMessage | None = None - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any: if isinstance(message, GetSlowUserMessage): self.question_for_user = message return message @@ -231,7 +230,7 @@ class TerminationHandler(DefaultInterventionHandler): def __init__(self): self.terminateMessage: TerminateMessage | None = None - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any: if isinstance(message, TerminateMessage): self.terminateMessage = message return message diff --git a/python/packages/autogen-core/src/autogen_core/_intervention.py b/python/packages/autogen-core/src/autogen_core/_intervention.py index 649026fea..752c831b8 100644 --- a/python/packages/autogen-core/src/autogen_core/_intervention.py +++ b/python/packages/autogen-core/src/autogen_core/_intervention.py @@ -1,6 +1,7 @@ from typing import Any, Protocol, final from ._agent_id import AgentId +from ._message_context import MessageContext __all__ = [ "DropMessage", @@ -10,20 +11,59 @@ __all__ = [ @final -class DropMessage: ... +class DropMessage: + """Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler.""" + + ... class InterventionHandler(Protocol): """An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`. + The handler is called when the message is submitted to the runtime. + + Currently the only runtime which supports this is the :class:`autogen_core.base.SingleThreadedAgentRuntime`. + Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly. + + Example: + + .. code-block:: python + + from autogen_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime + from dataclasses import dataclass + from typing import Any + + + @dataclass + class MyMessage: + content: str + + + class MyInterventionHandler(DefaultInterventionHandler): + async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage: + if isinstance(message, MyMessage): + message.content = message.content.upper() + return message + + + runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()]) + """ - async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ... - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ... - async def on_response( - self, message: Any, *, sender: AgentId, recipient: AgentId | None - ) -> Any | type[DropMessage]: ... + async def on_send( + self, message: Any, *, message_context: MessageContext, recipient: AgentId + ) -> Any | type[DropMessage]: + """Called when a message is submitted to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.send_message`.""" + ... + + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]: + """Called when a message is published to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.publish_message`.""" + ... + + async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]: + """Called when a response is received by the AgentRuntime from an Agent's message handler returning a value.""" + ... class DefaultInterventionHandler(InterventionHandler): @@ -31,10 +71,12 @@ class DefaultInterventionHandler(InterventionHandler): handler methods, that simply returns the message unchanged. Allows for easy subclassing to override only the desired methods.""" - async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: + async def on_send( + self, message: Any, *, message_context: MessageContext, recipient: AgentId + ) -> Any | type[DropMessage]: return message - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]: return message async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]: diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 9c292b9f2..f8f366921 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -474,7 +474,16 @@ class SingleThreadedAgentRuntime(AgentRuntime): "intercept", handler.__class__.__name__, parent=message_envelope.metadata ): try: - temp_message = await handler.on_send(message, sender=sender, recipient=recipient) + message_context = MessageContext( + sender=sender, + topic_id=None, + is_rpc=True, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + temp_message = await handler.on_send( + message, message_context=message_context, recipient=recipient + ) _warn_if_none(temp_message, "on_send") except BaseException as e: future.set_exception(e) @@ -506,7 +515,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): "intercept", handler.__class__.__name__, parent=message_envelope.metadata ): try: - temp_message = await handler.on_publish(message, sender=sender) + message_context = MessageContext( + sender=sender, + topic_id=topic_id, + is_rpc=False, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) + temp_message = await handler.on_publish(message, message_context=message_context) _warn_if_none(temp_message, "on_publish") except BaseException as e: # TODO: we should raise the intervention exception to the publisher. diff --git a/python/packages/autogen-core/tests/test_intervention.py b/python/packages/autogen-core/tests/test_intervention.py index ef6ee4ebf..fdd5654ff 100644 --- a/python/packages/autogen-core/tests/test_intervention.py +++ b/python/packages/autogen-core/tests/test_intervention.py @@ -1,5 +1,15 @@ +from typing import Any + import pytest -from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, SingleThreadedAgentRuntime +from autogen_core import ( + AgentId, + DefaultInterventionHandler, + DefaultSubscription, + DefaultTopicId, + DropMessage, + MessageContext, + SingleThreadedAgentRuntime, +) from autogen_core.exceptions import MessageDroppedException from autogen_test_utils import LoopbackAgent, MessageType @@ -8,10 +18,20 @@ from autogen_test_utils import LoopbackAgent, MessageType async def test_intervention_count_messages() -> None: class DebugInterventionHandler(DefaultInterventionHandler): def __init__(self) -> None: - self.num_messages = 0 + self.num_send_messages = 0 + self.num_publish_messages = 0 + self.num_response_messages = 0 - async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType: - self.num_messages += 1 + async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> Any: + self.num_send_messages += 1 + return message + + async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any: + self.num_publish_messages += 1 + return message + + async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any: + self.num_response_messages += 1 return message handler = DebugInterventionHandler() @@ -22,18 +42,28 @@ async def test_intervention_count_messages() -> None: _response = await runtime.send_message(MessageType(), recipient=loopback) - await runtime.stop() + await runtime.stop_when_idle() - assert handler.num_messages == 1 + assert handler.num_send_messages == 1 + assert handler.num_response_messages == 1 loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert loopback_agent.num_calls == 1 + runtime.start() + await runtime.add_subscription(DefaultSubscription(agent_type="name")) + + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + + await runtime.stop_when_idle() + assert loopback_agent.num_calls == 2 + assert handler.num_publish_messages == 1 + @pytest.mark.asyncio async def test_intervention_drop_send() -> None: class DropSendInterventionHandler(DefaultInterventionHandler): async def on_send( - self, message: MessageType, *, sender: AgentId | None, recipient: AgentId + self, message: MessageType, *, message_context: MessageContext, recipient: AgentId ) -> MessageType | type[DropMessage]: return DropMessage @@ -81,7 +111,7 @@ async def test_intervention_raise_exception_on_send() -> None: class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore async def on_send( - self, message: MessageType, *, sender: AgentId | None, recipient: AgentId + self, message: MessageType, *, message_context: MessageContext, recipient: AgentId ) -> MessageType | type[DropMessage]: # type: ignore raise InterventionException