mirror of
https://github.com/microsoft/autogen.git
synced 2026-05-13 03:00:55 -04:00
Initial impl of topics and subscriptions (#350)
* initial impl of topics and subscriptions * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * add topic in context * migrate * migrate code for topics * migrate team one * edit notebooks * formatting * fix imports * Build proto * Fix circular import --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
@@ -30,6 +31,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@@ -100,10 +102,12 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
AssistantMessage(content=response.content, source=self.metadata["type"])
|
||||
)
|
||||
|
||||
assert ctx.topic_id is not None
|
||||
# Publish the code execution task.
|
||||
await self.publish_message(
|
||||
CodeExecutionTask(content=response.content, session_id=session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
@message_handler
|
||||
@@ -120,8 +124,11 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
|
||||
if "TERMINATE" in response.content:
|
||||
# If the task is completed, publish a message with the completion content.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
TaskCompletion(content=response.content), cancellation_token=ctx.cancellation_token
|
||||
TaskCompletion(content=response.content),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
print("--------------------")
|
||||
print("Task completed:")
|
||||
@@ -129,9 +136,11 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
return
|
||||
|
||||
# Publish the code execution task.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTask(content=response.content, session_id=message.session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -148,11 +157,13 @@ class Executor(TypeRoutedAgent):
|
||||
code_blocks = self._extract_code_blocks(message.content)
|
||||
if not code_blocks:
|
||||
# If no code block is found, publish a message with an error.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(
|
||||
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
|
||||
),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
return
|
||||
# Execute code blocks.
|
||||
@@ -160,9 +171,11 @@ class Executor(TypeRoutedAgent):
|
||||
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
# Publish the code execution result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
|
||||
@@ -185,10 +198,12 @@ async def main(task: str, temp_dir: str) -> None:
|
||||
"coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"))
|
||||
)
|
||||
await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)))
|
||||
await runtime.add_subscription(TypeSubscription("default", "coder"))
|
||||
await runtime.add_subscription(TypeSubscription("default", "executor"))
|
||||
run_context = runtime.start()
|
||||
|
||||
# Publish the task message.
|
||||
await runtime.publish_message(TaskMessage(content=task), namespace="default")
|
||||
await runtime.publish_message(TaskMessage(content=task), topic_id=TopicId("default", "default"))
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import Dict, List, Union
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@@ -29,6 +30,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@@ -110,12 +112,14 @@ Please review the code and provide feedback.
|
||||
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
|
||||
approved = review["approval"].lower().strip() == "approve"
|
||||
# Publish the review result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeReviewResult(
|
||||
review=review_text,
|
||||
approved=approved,
|
||||
session_id=message.session_id,
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -179,7 +183,11 @@ Code: <Your code>
|
||||
# Store the code review task in the session memory.
|
||||
self._session_memory[session_id].append(code_review_task)
|
||||
# Publish a code review task.
|
||||
await self.publish_message(code_review_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
code_review_task,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
@message_handler
|
||||
async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:
|
||||
@@ -193,12 +201,14 @@ Code: <Your code>
|
||||
# Check if the code is approved.
|
||||
if message.approved:
|
||||
# Publish the code writing result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeWritingResult(
|
||||
code=review_request.code,
|
||||
task=review_request.code_writing_task,
|
||||
review=message.review,
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
print("Code Writing Result:")
|
||||
print("-" * 80)
|
||||
@@ -237,7 +247,11 @@ Code: <Your code>
|
||||
# Store the code review task in the session memory.
|
||||
self._session_memory[message.session_id].append(code_review_task)
|
||||
# Publish a new code review task.
|
||||
await self.publish_message(code_review_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
code_review_task,
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
def _extract_code_block(self, markdown_text: str) -> Union[str, None]:
|
||||
pattern = r"```(\w+)\n(.*?)\n```"
|
||||
@@ -258,6 +272,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReviewerAgent"))
|
||||
await runtime.register(
|
||||
"CoderAgent",
|
||||
lambda: CoderAgent(
|
||||
@@ -265,12 +280,13 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "CoderAgent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(
|
||||
message=CodeWritingTask(
|
||||
task="Write a function to find the directory with the largest number of files using multi-processing."
|
||||
),
|
||||
namespace="default",
|
||||
topic_id=TopicId("default", "default"),
|
||||
)
|
||||
|
||||
# Keep processing messages until idle.
|
||||
|
||||
@@ -26,7 +26,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import AgentId
|
||||
from agnext.core import AgentId, AgentInstantiationContext, TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@@ -69,7 +69,8 @@ class RoundRobinGroupChatManager(TypeRoutedAgent):
|
||||
self._round_count += 1
|
||||
if self._round_count > self._num_rounds * len(self._participants):
|
||||
# End the conversation after the specified number of rounds.
|
||||
await self.publish_message(Termination())
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Termination(), ctx.topic_id)
|
||||
return
|
||||
# Send a request to speak message to the selected speaker.
|
||||
await self.send_message(RequestToSpeak(), speaker)
|
||||
@@ -104,9 +105,10 @@ class GroupChatParticipant(TypeRoutedAgent):
|
||||
llm_messages.append(UserMessage(content=m.content, source=m.source))
|
||||
response = await self._model_client.create(self._system_messages + llm_messages)
|
||||
assert isinstance(response.content, str)
|
||||
speach = Message(content=response.content, source=self.metadata["type"])
|
||||
self._memory.append(speach)
|
||||
await self.publish_message(speach)
|
||||
speech = Message(content=response.content, source=self.metadata["type"])
|
||||
self._memory.append(speech)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(speech, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@@ -114,7 +116,7 @@ async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Register the participants.
|
||||
agent1 = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"DataScientist",
|
||||
lambda: GroupChatParticipant(
|
||||
description="A data scientist",
|
||||
@@ -122,7 +124,8 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
agent2 = await runtime.register_and_get(
|
||||
|
||||
await runtime.register(
|
||||
"Engineer",
|
||||
lambda: GroupChatParticipant(
|
||||
description="An engineer",
|
||||
@@ -130,7 +133,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
|
||||
),
|
||||
)
|
||||
agent3 = await runtime.register_and_get(
|
||||
await runtime.register(
|
||||
"Artist",
|
||||
lambda: GroupChatParticipant(
|
||||
description="An artist",
|
||||
@@ -144,7 +147,11 @@ async def main() -> None:
|
||||
"GroupChatManager",
|
||||
lambda: RoundRobinGroupChatManager(
|
||||
description="A group chat manager",
|
||||
participants=[agent1, agent2, agent3],
|
||||
participants=[
|
||||
AgentId("DataScientist", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Engineer", AgentInstantiationContext.current_agent_id().key),
|
||||
AgentId("Artist", AgentInstantiationContext.current_agent_id().key),
|
||||
],
|
||||
num_rounds=3,
|
||||
),
|
||||
)
|
||||
@@ -153,7 +160,9 @@ async def main() -> None:
|
||||
run_context = runtime.start()
|
||||
|
||||
# Start the conversation.
|
||||
await runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default")
|
||||
await runtime.publish_message(
|
||||
Message(content="Hello, everyone!", source="Moderator"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
@@ -16,11 +16,13 @@ from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
|
||||
from agnext.core import MessageContext
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from agnext.core import TopicId
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
|
||||
@@ -66,7 +68,8 @@ class ReferenceAgent(TypeRoutedAgent):
|
||||
response = await self._model_client.create(self._system_messages + [task_message])
|
||||
assert isinstance(response.content, str)
|
||||
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
class AggregatorAgent(TypeRoutedAgent):
|
||||
@@ -90,7 +93,8 @@ class AggregatorAgent(TypeRoutedAgent):
|
||||
"""Handle a task message. This method publishes the task to the reference agents."""
|
||||
session_id = str(uuid.uuid4())
|
||||
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
|
||||
await self.publish_message(ref_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(ref_task, topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
|
||||
@@ -104,7 +108,8 @@ class AggregatorAgent(TypeRoutedAgent):
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
task_result = AggregatorTaskResult(result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
self._session_results.pop(message.session_id)
|
||||
print(f"Aggregator result: {response.content}")
|
||||
|
||||
@@ -120,6 +125,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.1),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent1"))
|
||||
await runtime.register(
|
||||
"ReferenceAgent2",
|
||||
lambda: ReferenceAgent(
|
||||
@@ -128,6 +134,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.5),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent2"))
|
||||
await runtime.register(
|
||||
"ReferenceAgent3",
|
||||
lambda: ReferenceAgent(
|
||||
@@ -136,6 +143,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=1.0),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent3"))
|
||||
await runtime.register(
|
||||
"AggregatorAgent",
|
||||
lambda: AggregatorAgent(
|
||||
@@ -149,8 +157,11 @@ async def main() -> None:
|
||||
num_references=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "AggregatorAgent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
|
||||
await runtime.publish_message(
|
||||
AggregatorTask(task="What are something fun to do in SF?"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
|
||||
# Keep processing messages.
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
@@ -41,6 +41,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@@ -48,6 +49,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@@ -163,9 +165,12 @@ class MathSolver(TypeRoutedAgent):
|
||||
answer = match.group(1)
|
||||
# Increment the counter.
|
||||
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
|
||||
assert ctx.topic_id is not None
|
||||
if self._counters[message.session_id] == self._max_round:
|
||||
# If the counter reaches the maximum round, publishes a final response.
|
||||
await self.publish_message(FinalSolverResponse(answer=answer, session_id=message.session_id))
|
||||
await self.publish_message(
|
||||
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
|
||||
)
|
||||
else:
|
||||
# Publish intermediate response.
|
||||
await self.publish_message(
|
||||
@@ -175,7 +180,8 @@ class MathSolver(TypeRoutedAgent):
|
||||
answer=answer,
|
||||
session_id=message.session_id,
|
||||
round=self._counters[message.session_id],
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -193,7 +199,10 @@ class MathAggregator(TypeRoutedAgent):
|
||||
"in the form of {{answer}}, at the end of your response."
|
||||
)
|
||||
session_id = str(uuid.uuid4())
|
||||
await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
|
||||
)
|
||||
|
||||
@message_handler
|
||||
async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None:
|
||||
@@ -203,7 +212,8 @@ class MathAggregator(TypeRoutedAgent):
|
||||
answers = [resp.answer for resp in self._responses[message.session_id]]
|
||||
majority_answer = max(set(answers), key=answers.count)
|
||||
# Publish the aggregated response.
|
||||
await self.publish_message(Answer(content=majority_answer))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
|
||||
# Clear the responses.
|
||||
self._responses.pop(message.session_id)
|
||||
print(f"Aggregated answer: {majority_answer}")
|
||||
@@ -223,6 +233,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver1"))
|
||||
await runtime.register(
|
||||
"MathSolver2",
|
||||
lambda: MathSolver(
|
||||
@@ -231,6 +242,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver2"))
|
||||
await runtime.register(
|
||||
"MathSolver3",
|
||||
lambda: MathSolver(
|
||||
@@ -239,6 +251,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver3"))
|
||||
await runtime.register(
|
||||
"MathSolver4",
|
||||
lambda: MathSolver(
|
||||
@@ -247,13 +260,14 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver4"))
|
||||
# Register the aggregator agent.
|
||||
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a math problem to the aggregator agent.
|
||||
await runtime.publish_message(Question(content=question), namespace="default")
|
||||
await runtime.publish_message(Question(content=question), topic_id=TopicId("default", "default"))
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user