mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-14 22:08:09 -05:00
Migrate to using default sub/topic (#403)
This commit is contained in:
@@ -21,7 +21,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import RoutedAgent, message_handler
|
||||
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
|
||||
from agnext.components.models import (
|
||||
@@ -102,12 +102,11 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
AssistantMessage(content=response.content, source=self.metadata["type"])
|
||||
)
|
||||
|
||||
assert ctx.topic_id is not None
|
||||
# Publish the code execution task.
|
||||
await self.publish_message(
|
||||
CodeExecutionTask(content=response.content, session_id=session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
@message_handler
|
||||
@@ -124,11 +123,10 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
|
||||
if "TERMINATE" in response.content:
|
||||
# If the task is completed, publish a message with the completion content.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
TaskCompletion(content=response.content),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
print("--------------------")
|
||||
print("Task completed:")
|
||||
@@ -136,11 +134,10 @@ Reply "TERMINATE" in the end when everything is done."""
|
||||
return
|
||||
|
||||
# Publish the code execution task.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTask(content=response.content, session_id=message.session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
|
||||
@@ -157,13 +154,12 @@ class Executor(RoutedAgent):
|
||||
code_blocks = self._extract_code_blocks(message.content)
|
||||
if not code_blocks:
|
||||
# If no code block is found, publish a message with an error.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(
|
||||
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
|
||||
),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
return
|
||||
# Execute code blocks.
|
||||
@@ -171,11 +167,10 @@ class Executor(RoutedAgent):
|
||||
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
|
||||
)
|
||||
# Publish the code execution result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:
|
||||
|
||||
@@ -21,7 +21,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import RoutedAgent, message_handler
|
||||
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
@@ -112,14 +112,13 @@ Please review the code and provide feedback.
|
||||
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
|
||||
approved = review["approval"].lower().strip() == "approve"
|
||||
# Publish the review result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeReviewResult(
|
||||
review=review_text,
|
||||
approved=approved,
|
||||
session_id=message.session_id,
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
|
||||
@@ -183,10 +182,9 @@ Code: <Your code>
|
||||
# Store the code review task in the session memory.
|
||||
self._session_memory[session_id].append(code_review_task)
|
||||
# Publish a code review task.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
code_review_task,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
@message_handler
|
||||
@@ -201,14 +199,13 @@ Code: <Your code>
|
||||
# Check if the code is approved.
|
||||
if message.approved:
|
||||
# Publish the code writing result.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
CodeWritingResult(
|
||||
code=review_request.code,
|
||||
task=review_request.code_writing_task,
|
||||
review=message.review,
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
print("Code Writing Result:")
|
||||
print("-" * 80)
|
||||
@@ -247,10 +244,9 @@ Code: <Your code>
|
||||
# Store the code review task in the session memory.
|
||||
self._session_memory[message.session_id].append(code_review_task)
|
||||
# Publish a new code review task.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
code_review_task,
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
def _extract_code_block(self, markdown_text: str) -> Union[str, None]:
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import RoutedAgent, message_handler
|
||||
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@@ -69,8 +69,7 @@ class RoundRobinGroupChatManager(RoutedAgent):
|
||||
self._round_count += 1
|
||||
if self._round_count > self._num_rounds * len(self._participants):
|
||||
# End the conversation after the specified number of rounds.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Termination(), ctx.topic_id)
|
||||
await self.publish_message(Termination(), DefaultTopicId())
|
||||
return
|
||||
# Send a request to speak message to the selected speaker.
|
||||
await self.send_message(RequestToSpeak(), speaker)
|
||||
@@ -107,8 +106,7 @@ class GroupChatParticipant(RoutedAgent):
|
||||
assert isinstance(response.content, str)
|
||||
speech = Message(content=response.content, source=self.metadata["type"])
|
||||
self._memory.append(speech)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(speech, topic_id=ctx.topic_id)
|
||||
await self.publish_message(speech, topic_id=DefaultTopicId())
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
||||
@@ -15,7 +15,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import RoutedAgent, message_handler
|
||||
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
|
||||
from agnext.core import MessageContext
|
||||
@@ -68,8 +68,7 @@ class ReferenceAgent(RoutedAgent):
|
||||
response = await self._model_client.create(self._system_messages + [task_message])
|
||||
assert isinstance(response.content, str)
|
||||
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
await self.publish_message(task_result, topic_id=DefaultTopicId())
|
||||
|
||||
|
||||
class AggregatorAgent(RoutedAgent):
|
||||
@@ -93,8 +92,7 @@ class AggregatorAgent(RoutedAgent):
|
||||
"""Handle a task message. This method publishes the task to the reference agents."""
|
||||
session_id = str(uuid.uuid4())
|
||||
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(ref_task, topic_id=ctx.topic_id)
|
||||
await self.publish_message(ref_task, topic_id=DefaultTopicId())
|
||||
|
||||
@message_handler
|
||||
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
|
||||
@@ -108,8 +106,7 @@ class AggregatorAgent(RoutedAgent):
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
task_result = AggregatorTaskResult(result=response.content)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
await self.publish_message(task_result, topic_id=DefaultTopicId())
|
||||
self._session_results.pop(message.session_id)
|
||||
print(f"Aggregator result: {response.content}")
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import RoutedAgent, message_handler
|
||||
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
@@ -165,11 +165,10 @@ class MathSolver(RoutedAgent):
|
||||
answer = match.group(1)
|
||||
# Increment the counter.
|
||||
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
|
||||
assert ctx.topic_id is not None
|
||||
if self._counters[message.session_id] == self._max_round:
|
||||
# If the counter reaches the maximum round, publishes a final response.
|
||||
await self.publish_message(
|
||||
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
|
||||
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=DefaultTopicId()
|
||||
)
|
||||
else:
|
||||
# Publish intermediate response.
|
||||
@@ -181,7 +180,7 @@ class MathSolver(RoutedAgent):
|
||||
session_id=message.session_id,
|
||||
round=self._counters[message.session_id],
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
topic_id=DefaultTopicId(),
|
||||
)
|
||||
|
||||
|
||||
@@ -199,9 +198,8 @@ class MathAggregator(RoutedAgent):
|
||||
"in the form of {{answer}}, at the end of your response."
|
||||
)
|
||||
session_id = str(uuid.uuid4())
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
|
||||
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=DefaultTopicId()
|
||||
)
|
||||
|
||||
@message_handler
|
||||
@@ -212,8 +210,7 @@ class MathAggregator(RoutedAgent):
|
||||
answers = [resp.answer for resp in self._responses[message.session_id]]
|
||||
majority_answer = max(set(answers), key=answers.count)
|
||||
# Publish the aggregated response.
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
|
||||
await self.publish_message(Answer(content=majority_answer), topic_id=DefaultTopicId())
|
||||
# Clear the responses.
|
||||
self._responses.pop(message.session_id)
|
||||
print(f"Aggregated answer: {majority_answer}")
|
||||
|
||||
Reference in New Issue
Block a user