diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index e489975536..827af68a56 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -70,7 +70,8 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): # Create a runtime for the team. # TODO: The runtime should be created by a managed context. - self._runtime = SingleThreadedAgentRuntime() + # Background exceptions must not be ignored, as ignoring them results in non-surfaced exceptions and early team termination. + self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) # Flag to track if the group chat has been initialized. self._initialized = False @@ -408,8 +409,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): # Start a coroutine to stop the runtime and signal the output message queue is complete. async def stop_runtime() -> None: - await self._runtime.stop_when_idle() - await self._output_message_queue.put(None) + try: + await self._runtime.stop_when_idle() + finally: + await self._output_message_queue.put(None) shutdown_task = asyncio.create_task(stop_runtime()) @@ -444,14 +447,17 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): finally: # Wait for the shutdown task to finish. - await shutdown_task + try: + # This will propagate any exceptions raised in the shutdown task. + # We need to ensure we clean up though. + await shutdown_task + finally: + # Clear the output message queue. + while not self._output_message_queue.empty(): + self._output_message_queue.get_nowait() - # Clear the output message queue. - while not self._output_message_queue.empty(): - self._output_message_queue.get_nowait() - - # Indicate that the team is no longer running. 
- self._is_running = False + # Indicate that the team is no longer running. + self._is_running = False async def reset(self) -> None: """Reset the team and its participants to their initial state. diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index d12b24f04a..5c1c5a6aa6 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -100,6 +100,27 @@ class _EchoAgent(BaseChatAgent): self._last_message = None +class _FlakyAgent(BaseChatAgent): + def __init__(self, name: str, description: str) -> None: + super().__init__(name, description) + self._last_message: str | None = None + self._total_messages = 0 + + @property + def produced_message_types(self) -> Sequence[type[ChatMessage]]: + return (TextMessage,) + + @property + def total_messages(self) -> int: + return self._total_messages + + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + raise ValueError("I am a flaky agent...") + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + self._last_message = None + + class _StopAgent(_EchoAgent): def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None: super().__init__(name, description) @@ -400,6 +421,23 @@ async def test_round_robin_group_chat_with_resume_and_reset() -> None: assert result.stop_reason is not None +@pytest.mark.asyncio +async def test_round_robin_group_chat_with_exception_raised() -> None: + agent_1 = _EchoAgent("agent_1", description="echo agent 1") + agent_2 = _FlakyAgent("agent_2", description="echo agent 2") + agent_3 = _EchoAgent("agent_3", description="echo agent 3") + termination = MaxMessageTermination(3) + team = RoundRobinGroupChat( + participants=[agent_1, agent_2, agent_3], + termination_condition=termination, + ) + + with pytest.raises(ValueError, match="I am a flaky 
agent..."): + await team.run( + task="Write a program that prints 'Hello, world!'", + ) + + @pytest.mark.asyncio async def test_round_robin_group_chat_max_turn() -> None: agent_1 = _EchoAgent("agent_1", description="echo agent 1") diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index f80be46a58..65ed6a8fb3 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -159,6 +159,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): intervention_handlers (List[InterventionHandler], optional): A list of intervention handlers that can intercept messages before they are sent or published. Defaults to None. tracer_provider (TracerProvider, optional): The tracer provider to use for tracing. Defaults to None. + ignore_unhandled_exceptions (bool, optional): Whether to ignore unhandled exceptions that occur in agent event handlers. If False, any background exceptions will be raised on the next call to `process_next` or from an awaited `stop`, `stop_when_idle` or `stop_when`. Note, this does not apply to RPC handlers. Defaults to True. 
Examples: @@ -248,6 +249,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): *, intervention_handlers: List[InterventionHandler] | None = None, tracer_provider: TracerProvider | None = None, + ignore_unhandled_exceptions: bool = True, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue() @@ -261,6 +263,8 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._subscription_manager = SubscriptionManager() self._run_context: RunContext | None = None self._serialization_registry = SerializationRegistry() + self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions + self._background_exception: BaseException | None = None @property def unprocessed_messages_count( @@ -515,15 +519,15 @@ class SingleThreadedAgentRuntime(AgentRuntime): exception=e, ) ) - raise + raise e future = _on_message(agent, message_context) responses.append(future) await asyncio.gather(*responses) - except BaseException: - # Ignore exceptions raised during publishing. We've already logged them above. - pass + except BaseException as e: + if not self._ignore_unhandled_handler_exceptions: + self._background_exception = e finally: self._message_queue.task_done() # TODO if responses are given for a publish @@ -552,15 +556,28 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._message_queue.task_done() async def process_next(self) -> None: - """Process the next message in the queue.""" + """Process the next message in the queue. + + If there is an unhandled exception in the background task, it will be raised here. `process_next` cannot be called again after an unhandled exception is raised. 
+ """ await self._process_next() async def _process_next(self) -> None: """Process the next message in the queue.""" + if self._background_exception is not None: + e = self._background_exception + self._background_exception = None + self._message_queue.shutdown(immediate=True) # type: ignore + raise e + try: message_envelope = await self._message_queue.get() except QueueShutDown: + if self._background_exception is not None: + e = self._background_exception + self._background_exception = None + raise e from None return match message_envelope: @@ -637,6 +654,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): return message_envelope.message = temp_message + task = asyncio.create_task(self._process_publish(message_envelope)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -711,19 +729,23 @@ class SingleThreadedAgentRuntime(AgentRuntime): if self._run_context is None: raise RuntimeError("Runtime is not started") - await self._run_context.stop() - self._run_context = None - self._message_queue = Queue() + try: + await self._run_context.stop() + finally: + self._run_context = None + self._message_queue = Queue() async def stop_when_idle(self) -> None: """Stop the runtime message processing loop when there is no outstanding message being processed or queued. This is the most common way to stop the runtime.""" if self._run_context is None: raise RuntimeError("Runtime is not started") - await self._run_context.stop_when_idle() - self._run_context = None - self._message_queue = Queue() + try: + await self._run_context.stop_when_idle() + finally: + self._run_context = None + self._message_queue = Queue() async def stop_when(self, condition: Callable[[], bool]) -> None: """Stop the runtime message processing loop when the condition is met. 
diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 9a0e275074..64a1cccf4b 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -6,12 +6,16 @@ from autogen_core import ( AgentInstantiationContext, AgentType, DefaultTopicId, + MessageContext, + RoutedAgent, SingleThreadedAgentRuntime, TopicId, TypeSubscription, + event, try_get_known_serializers_for_type, type_subscription, ) +from autogen_core._default_subscription import default_subscription from autogen_test_utils import ( CascadingAgent, CascadingMessageType, @@ -268,3 +272,41 @@ async def test_default_subscription_publish_to_other_source() -> None: assert other_long_running_agent.num_calls == 1 await runtime.close() + + +@default_subscription +class FailingAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("A failing agent.") + + @event + async def on_new_message_event(self, message: MessageType, ctx: MessageContext) -> None: + raise ValueError("Test exception") + + +@pytest.mark.asyncio +async def test_event_handler_exception_propogates() -> None: + runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + await FailingAgent.register(runtime, "name", FailingAgent) + + with pytest.raises(ValueError, match="Test exception"): + runtime.start() + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.stop_when_idle() + + await runtime.close() + + +@pytest.mark.asyncio +async def test_event_handler_exception_multi_message() -> None: + runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + await FailingAgent.register(runtime, "name", FailingAgent) + + with pytest.raises(ValueError, match="Test exception"): + runtime.start() + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await 
runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.stop_when_idle() + + await runtime.close()