Team termination condition sets in the constructor (#4042)

* Termination condition as part of constructor

* Update doc

* Update notebooks
This commit is contained in:
Eric Zhu
2024-11-01 15:49:37 -07:00
committed by GitHub
parent 7d1857dae6
commit 4fec22ddc5
13 changed files with 134 additions and 97 deletions

View File

@@ -33,5 +33,5 @@ class TaskRunner(Protocol):
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
as the last item in the stream."""
:class:`TaskResult` as the last item in the stream."""
...

View File

@@ -33,6 +33,7 @@ class BaseGroupChat(Team, ABC):
self,
participants: List[ChatAgent],
group_chat_manager_class: type[BaseGroupChatManager],
termination_condition: TerminationCondition | None = None,
):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
@@ -41,6 +42,7 @@ class BaseGroupChat(Team, ABC):
self._participants = participants
self._team_id = str(uuid.uuid4())
self._base_group_chat_manager_class = group_chat_manager_class
self._termination_condition = termination_condition
@abstractmethod
def _create_group_chat_manager_factory(
@@ -72,12 +74,12 @@ class BaseGroupChat(Team, ABC):
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
:meth:`run_stream` to run the team and then returns the final result."""
async for message in self.run_stream(
task, cancellation_token=cancellation_token, termination_condition=termination_condition
task,
cancellation_token=cancellation_token,
):
if isinstance(message, TaskResult):
return message
@@ -88,10 +90,9 @@ class BaseGroupChat(Team, ABC):
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
as the last item in the stream."""
of the type :class:`TaskResult` as the last item in the stream."""
# Create the runtime.
runtime = SingleThreadedAgentRuntime()
@@ -131,7 +132,7 @@ class BaseGroupChat(Team, ABC):
group_topic_type=group_topic_type,
participant_topic_types=participant_topic_types,
participant_descriptions=participant_descriptions,
termination_condition=termination_condition,
termination_condition=self._termination_condition,
),
)
# Add subscriptions for the group chat manager.

View File

@@ -50,7 +50,8 @@ class RoundRobinGroupChat(BaseGroupChat):
Args:
participants (List[BaseChatAgent]): The participants in the group chat.
tools (List[Tool], optional): The tools to use in the group chat. Defaults to None.
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
Without a termination condition, the group chat will run indefinitely.
Raises:
ValueError: If no participants are provided or if participant names are not unique.
@@ -65,7 +66,7 @@ class RoundRobinGroupChat(BaseGroupChat):
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
from autogen_agentchat.task import TextMentionTermination
async def main() -> None:
@@ -79,8 +80,9 @@ class RoundRobinGroupChat(BaseGroupChat):
model_client=model_client,
tools=[get_weather],
)
team = RoundRobinGroupChat([assistant])
stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination())
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat([assistant], termination_condition=termination)
stream = team.run_stream("What's the weather in New York?")
async for message in stream:
print(message)
@@ -95,7 +97,7 @@ class RoundRobinGroupChat(BaseGroupChat):
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.task import StopMessageTermination
from autogen_agentchat.task import TextMentionTermination
async def main() -> None:
@@ -103,8 +105,9 @@ class RoundRobinGroupChat(BaseGroupChat):
agent1 = AssistantAgent("Assistant1", model_client=model_client)
agent2 = AssistantAgent("Assistant2", model_client=model_client)
team = RoundRobinGroupChat([agent1, agent2])
stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination())
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
stream = team.run_stream("Tell me some jokes.")
async for message in stream:
print(message)
@@ -112,10 +115,13 @@ class RoundRobinGroupChat(BaseGroupChat):
asyncio.run(main())
"""
def __init__(self, participants: List[ChatAgent]):
def __init__(
self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None
) -> None:
super().__init__(
participants,
group_chat_manager_class=RoundRobinGroupChatManager,
termination_condition=termination_condition,
)
def _create_group_chat_manager_factory(

View File

@@ -169,6 +169,8 @@ class SelectorGroupChat(BaseGroupChat):
must have unique names and at least two participants.
model_client (ChatCompletionClient): The ChatCompletion model client used
to select the next speaker.
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
Without a termination condition, the group chat will run indefinitely.
selector_prompt (str, optional): The prompt template to use for selecting the next speaker.
Must contain '{roles}', '{participants}', and '{history}' to be filled in.
allow_repeated_speaker (bool, optional): Whether to allow the same speaker to be selected
@@ -187,10 +189,11 @@ class SelectorGroupChat(BaseGroupChat):
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.task import StopMessageTermination
from autogen_agentchat.task import TextMentionTermination
async def main() -> None:
@@ -223,20 +226,24 @@ class SelectorGroupChat(BaseGroupChat):
tools=[lookup_flight],
description="Helps with flight booking.",
)
team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client)
stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination())
termination = TextMentionTermination("TERMINATE")
team = SelectorGroupChat(
[travel_advisor, hotel_agent, flight_agent],
model_client=model_client,
termination_condition=termination,
)
stream = team.run_stream("Book a 3-day trip to new york.")
async for message in stream:
print(message)
import asyncio
asyncio.run(main())
A team with a custom selector function:
.. code-block:: python
import asyncio
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
@@ -273,15 +280,19 @@ class SelectorGroupChat(BaseGroupChat):
return "Agent2"
return None
team = SelectorGroupChat([agent1, agent2], model_client=model_client, selector_func=selector_func)
termination = TextMentionTermination("Correct!")
team = SelectorGroupChat(
[agent1, agent2],
model_client=model_client,
selector_func=selector_func,
termination_condition=termination,
)
stream = team.run_stream("What is 1 + 1?", termination_condition=TextMentionTermination("Correct!"))
stream = team.run_stream("What is 1 + 1?")
async for message in stream:
print(message)
import asyncio
asyncio.run(main())
"""
@@ -290,6 +301,7 @@ class SelectorGroupChat(BaseGroupChat):
participants: List[ChatAgent],
model_client: ChatCompletionClient,
*,
termination_condition: TerminationCondition | None = None,
selector_prompt: str = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
@@ -301,7 +313,9 @@ Read the above conversation. Then select the next role from {participants} to pl
allow_repeated_speaker: bool = False,
selector_func: Callable[[Sequence[ChatMessage]], str | None] | None = None,
):
super().__init__(participants, group_chat_manager_class=SelectorGroupChatManager)
super().__init__(
participants, group_chat_manager_class=SelectorGroupChatManager, termination_condition=termination_condition
)
# Validate the participants.
if len(participants) < 2:
raise ValueError("At least two participants are required for SelectorGroupChat.")

View File

@@ -56,6 +56,8 @@ class Swarm(BaseGroupChat):
Args:
participants (List[ChatAgent]): The agents participating in the group chat. The first agent in the list is the initial speaker.
termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
Without a termination condition, the group chat will run indefinitely.
Examples:
@@ -81,9 +83,10 @@ class Swarm(BaseGroupChat):
"Bob", model_client=model_client, system_message="You are Bob and your birthday is on 1st January."
)
team = Swarm([agent1, agent2])
termination = MaxMessageTermination(3)
team = Swarm([agent1, agent2], termination_condition=termination)
stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3))
stream = team.run_stream("What is bob's birthday?")
async for message in stream:
print(message)
@@ -91,8 +94,12 @@ class Swarm(BaseGroupChat):
asyncio.run(main())
"""
def __init__(self, participants: List[ChatAgent]):
super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager)
def __init__(
self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None
) -> None:
super().__init__(
participants, group_chat_manager_class=SwarmGroupChatManager, termination_condition=termination_condition
)
# The first participant must be able to produce handoff messages.
first_participant = self._participants[0]
if HandoffMessage not in first_participant.produced_message_types: