mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Allow background exceptions to be fatal (#5716)
Closes #4904 Does not change default behavior in core. In agentchat, this change will mean that exceptions that used to be ignored and result in bugs like the group chat stopping are now reported out to the user application. --------- Co-authored-by: Ben Constable <benconstable@microsoft.com> Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
@@ -70,7 +70,8 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
|
||||
# Create a runtime for the team.
|
||||
# TODO: The runtime should be created by a managed context.
|
||||
self._runtime = SingleThreadedAgentRuntime()
|
||||
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
|
||||
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
|
||||
# Flag to track if the group chat has been initialized.
|
||||
self._initialized = False
|
||||
@@ -408,8 +409,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
|
||||
# Start a coroutine to stop the runtime and signal the output message queue is complete.
|
||||
async def stop_runtime() -> None:
|
||||
await self._runtime.stop_when_idle()
|
||||
await self._output_message_queue.put(None)
|
||||
try:
|
||||
await self._runtime.stop_when_idle()
|
||||
finally:
|
||||
await self._output_message_queue.put(None)
|
||||
|
||||
shutdown_task = asyncio.create_task(stop_runtime())
|
||||
|
||||
@@ -444,14 +447,17 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
|
||||
finally:
|
||||
# Wait for the shutdown task to finish.
|
||||
await shutdown_task
|
||||
try:
|
||||
# This will propagate any exceptions raised in the shutdown task.
|
||||
# We need to ensure we cleanup though.
|
||||
await shutdown_task
|
||||
finally:
|
||||
# Clear the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
# Clear the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the team and its participants to their initial state.
|
||||
|
||||
@@ -100,6 +100,27 @@ class _EchoAgent(BaseChatAgent):
|
||||
self._last_message = None
|
||||
|
||||
|
||||
class _FlakyAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
super().__init__(name, description)
|
||||
self._last_message: str | None = None
|
||||
self._total_messages = 0
|
||||
|
||||
@property
|
||||
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
||||
return (TextMessage,)
|
||||
|
||||
@property
|
||||
def total_messages(self) -> int:
|
||||
return self._total_messages
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
||||
raise ValueError("I am a flaky agent...")
|
||||
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
self._last_message = None
|
||||
|
||||
|
||||
class _StopAgent(_EchoAgent):
|
||||
def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
|
||||
super().__init__(name, description)
|
||||
@@ -400,6 +421,23 @@ async def test_round_robin_group_chat_with_resume_and_reset() -> None:
|
||||
assert result.stop_reason is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_exception_raised() -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
|
||||
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
||||
termination = MaxMessageTermination(3)
|
||||
team = RoundRobinGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3],
|
||||
termination_condition=termination,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="I am a flaky agent..."):
|
||||
await team.run(
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_max_turn() -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
|
||||
@@ -159,6 +159,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
intervention_handlers (List[InterventionHandler], optional): A list of intervention
|
||||
handlers that can intercept messages before they are sent or published. Defaults to None.
|
||||
tracer_provider (TracerProvider, optional): The tracer provider to use for tracing. Defaults to None.
|
||||
ignore_unhandled_exceptions (bool, optional): Whether to ignore unhandled exceptions in that occur in agent event handlers. Any background exceptions will be raised on the next call to `process_next` or from an awaited `stop`, `stop_when_idle` or `stop_when`. Note, this does not apply to RPC handlers. Defaults to True.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -248,6 +249,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
*,
|
||||
intervention_handlers: List[InterventionHandler] | None = None,
|
||||
tracer_provider: TracerProvider | None = None,
|
||||
ignore_unhandled_exceptions: bool = True,
|
||||
) -> None:
|
||||
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
|
||||
self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue()
|
||||
@@ -261,6 +263,8 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._run_context: RunContext | None = None
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
|
||||
self._background_exception: BaseException | None = None
|
||||
|
||||
@property
|
||||
def unprocessed_messages_count(
|
||||
@@ -515,15 +519,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
raise
|
||||
raise e
|
||||
|
||||
future = _on_message(agent, message_context)
|
||||
responses.append(future)
|
||||
|
||||
await asyncio.gather(*responses)
|
||||
except BaseException:
|
||||
# Ignore exceptions raised during publishing. We've already logged them above.
|
||||
pass
|
||||
except BaseException as e:
|
||||
if not self._ignore_unhandled_handler_exceptions:
|
||||
self._background_exception = e
|
||||
finally:
|
||||
self._message_queue.task_done()
|
||||
# TODO if responses are given for a publish
|
||||
@@ -552,15 +556,28 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
self._message_queue.task_done()
|
||||
|
||||
async def process_next(self) -> None:
|
||||
"""Process the next message in the queue."""
|
||||
"""Process the next message in the queue.
|
||||
|
||||
If there is an unhandled exception in the background task, it will be raised here. `process_next` cannot be called again after an unhandled exception is raised.
|
||||
"""
|
||||
await self._process_next()
|
||||
|
||||
async def _process_next(self) -> None:
|
||||
"""Process the next message in the queue."""
|
||||
|
||||
if self._background_exception is not None:
|
||||
e = self._background_exception
|
||||
self._background_exception = None
|
||||
self._message_queue.shutdown(immediate=True) # type: ignore
|
||||
raise e
|
||||
|
||||
try:
|
||||
message_envelope = await self._message_queue.get()
|
||||
except QueueShutDown:
|
||||
if self._background_exception is not None:
|
||||
e = self._background_exception
|
||||
self._background_exception = None
|
||||
raise e from None
|
||||
return
|
||||
|
||||
match message_envelope:
|
||||
@@ -637,6 +654,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
|
||||
task = asyncio.create_task(self._process_publish(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
@@ -711,19 +729,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
|
||||
await self._run_context.stop()
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
try:
|
||||
await self._run_context.stop()
|
||||
finally:
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
"""Stop the runtime message processing loop when there is
|
||||
no outstanding message being processed or queued. This is the most common way to stop the runtime."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when_idle()
|
||||
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
try:
|
||||
await self._run_context.stop_when_idle()
|
||||
finally:
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
"""Stop the runtime message processing loop when the condition is met.
|
||||
|
||||
@@ -6,12 +6,16 @@ from autogen_core import (
|
||||
AgentInstantiationContext,
|
||||
AgentType,
|
||||
DefaultTopicId,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
TopicId,
|
||||
TypeSubscription,
|
||||
event,
|
||||
try_get_known_serializers_for_type,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core._default_subscription import default_subscription
|
||||
from autogen_test_utils import (
|
||||
CascadingAgent,
|
||||
CascadingMessageType,
|
||||
@@ -268,3 +272,41 @@ async def test_default_subscription_publish_to_other_source() -> None:
|
||||
assert other_long_running_agent.num_calls == 1
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@default_subscription
|
||||
class FailingAgent(RoutedAgent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("A failing agent.")
|
||||
|
||||
@event
|
||||
async def on_new_message_event(self, message: MessageType, ctx: MessageContext) -> None:
|
||||
raise ValueError("Test exception")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_handler_exception_propogates() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
await FailingAgent.register(runtime, "name", FailingAgent)
|
||||
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
runtime.start()
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
await runtime.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_handler_exception_multi_message() -> None:
|
||||
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
await FailingAgent.register(runtime, "name", FailingAgent)
|
||||
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
runtime.start()
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
await runtime.close()
|
||||
|
||||
Reference in New Issue
Block a user