Allow background exceptions to be fatal (#5716)

Closes #4904 

Does not change default behavior in core.

In agentchat, this change will mean that exceptions that used to be
ignored and result in bugs like the group chat stopping are now reported
out to the user application.

---------

Co-authored-by: Ben Constable <benconstable@microsoft.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
Jack Gerrits
2025-02-26 13:34:53 -05:00
committed by GitHub
parent dc55ec964b
commit 6b68719939
4 changed files with 129 additions and 21 deletions

View File

@@ -70,7 +70,8 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
# Create a runtime for the team.
# TODO: The runtime should be created by a managed context.
self._runtime = SingleThreadedAgentRuntime()
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
# Flag to track if the group chat has been initialized.
self._initialized = False
@@ -408,8 +409,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
# Start a coroutine to stop the runtime and signal the output message queue is complete.
async def stop_runtime() -> None:
await self._runtime.stop_when_idle()
await self._output_message_queue.put(None)
try:
await self._runtime.stop_when_idle()
finally:
await self._output_message_queue.put(None)
shutdown_task = asyncio.create_task(stop_runtime())
@@ -444,14 +447,17 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
finally:
# Wait for the shutdown task to finish.
await shutdown_task
try:
# This will propagate any exceptions raised in the shutdown task.
# We need to ensure we cleanup though.
await shutdown_task
finally:
# Clear the output message queue.
while not self._output_message_queue.empty():
self._output_message_queue.get_nowait()
# Clear the output message queue.
while not self._output_message_queue.empty():
self._output_message_queue.get_nowait()
# Indicate that the team is no longer running.
self._is_running = False
# Indicate that the team is no longer running.
self._is_running = False
async def reset(self) -> None:
"""Reset the team and its participants to their initial state.

View File

@@ -100,6 +100,27 @@ class _EchoAgent(BaseChatAgent):
self._last_message = None
class _FlakyAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
self._last_message: str | None = None
self._total_messages = 0
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
return (TextMessage,)
@property
def total_messages(self) -> int:
return self._total_messages
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
raise ValueError("I am a flaky agent...")
async def on_reset(self, cancellation_token: CancellationToken) -> None:
self._last_message = None
class _StopAgent(_EchoAgent):
def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
super().__init__(name, description)
@@ -400,6 +421,23 @@ async def test_round_robin_group_chat_with_resume_and_reset() -> None:
assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_exception_raised() -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
termination = MaxMessageTermination(3)
team = RoundRobinGroupChat(
participants=[agent_1, agent_2, agent_3],
termination_condition=termination,
)
with pytest.raises(ValueError, match="I am a flaky agent..."):
await team.run(
task="Write a program that prints 'Hello, world!'",
)
@pytest.mark.asyncio
async def test_round_robin_group_chat_max_turn() -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")

View File

@@ -159,6 +159,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
intervention_handlers (List[InterventionHandler], optional): A list of intervention
handlers that can intercept messages before they are sent or published. Defaults to None.
tracer_provider (TracerProvider, optional): The tracer provider to use for tracing. Defaults to None.
ignore_unhandled_exceptions (bool, optional): Whether to ignore unhandled exceptions in that occur in agent event handlers. Any background exceptions will be raised on the next call to `process_next` or from an awaited `stop`, `stop_when_idle` or `stop_when`. Note, this does not apply to RPC handlers. Defaults to True.
Examples:
@@ -248,6 +249,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
*,
intervention_handlers: List[InterventionHandler] | None = None,
tracer_provider: TracerProvider | None = None,
ignore_unhandled_exceptions: bool = True,
) -> None:
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue()
@@ -261,6 +263,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None
@property
def unprocessed_messages_count(
@@ -515,15 +519,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
exception=e,
)
)
raise
raise e
future = _on_message(agent, message_context)
responses.append(future)
await asyncio.gather(*responses)
except BaseException:
# Ignore exceptions raised during publishing. We've already logged them above.
pass
except BaseException as e:
if not self._ignore_unhandled_handler_exceptions:
self._background_exception = e
finally:
self._message_queue.task_done()
# TODO if responses are given for a publish
@@ -552,15 +556,28 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._message_queue.task_done()
async def process_next(self) -> None:
"""Process the next message in the queue."""
"""Process the next message in the queue.
If there is an unhandled exception in the background task, it will be raised here. `process_next` cannot be called again after an unhandled exception is raised.
"""
await self._process_next()
async def _process_next(self) -> None:
"""Process the next message in the queue."""
if self._background_exception is not None:
e = self._background_exception
self._background_exception = None
self._message_queue.shutdown(immediate=True) # type: ignore
raise e
try:
message_envelope = await self._message_queue.get()
except QueueShutDown:
if self._background_exception is not None:
e = self._background_exception
self._background_exception = None
raise e from None
return
match message_envelope:
@@ -637,6 +654,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return
message_envelope.message = temp_message
task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
@@ -711,19 +729,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop()
self._run_context = None
self._message_queue = Queue()
try:
await self._run_context.stop()
finally:
self._run_context = None
self._message_queue = Queue()
async def stop_when_idle(self) -> None:
"""Stop the runtime message processing loop when there is
no outstanding message being processed or queued. This is the most common way to stop the runtime."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when_idle()
self._run_context = None
self._message_queue = Queue()
try:
await self._run_context.stop_when_idle()
finally:
self._run_context = None
self._message_queue = Queue()
async def stop_when(self, condition: Callable[[], bool]) -> None:
"""Stop the runtime message processing loop when the condition is met.

View File

@@ -6,12 +6,16 @@ from autogen_core import (
AgentInstantiationContext,
AgentType,
DefaultTopicId,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
TopicId,
TypeSubscription,
event,
try_get_known_serializers_for_type,
type_subscription,
)
from autogen_core._default_subscription import default_subscription
from autogen_test_utils import (
CascadingAgent,
CascadingMessageType,
@@ -268,3 +272,41 @@ async def test_default_subscription_publish_to_other_source() -> None:
assert other_long_running_agent.num_calls == 1
await runtime.close()
@default_subscription
class FailingAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("A failing agent.")
@event
async def on_new_message_event(self, message: MessageType, ctx: MessageContext) -> None:
raise ValueError("Test exception")
@pytest.mark.asyncio
async def test_event_handler_exception_propogates() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
await FailingAgent.register(runtime, "name", FailingAgent)
with pytest.raises(ValueError, match="Test exception"):
runtime.start()
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()
await runtime.close()
@pytest.mark.asyncio
async def test_event_handler_exception_multi_message() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
await FailingAgent.register(runtime, "name", FailingAgent)
with pytest.raises(ValueError, match="Test exception"):
runtime.start()
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()
await runtime.close()