Migrate to using default sub/topic (#403)

This commit is contained in:
Jack Gerrits
2024-08-26 10:30:28 -04:00
committed by GitHub
parent d7ae2038fb
commit dbb35fc335
23 changed files with 491 additions and 526 deletions

View File

@@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor
from agnext.components.models import (
@@ -102,12 +102,11 @@ Reply "TERMINATE" in the end when everything is done."""
AssistantMessage(content=response.content, source=self.metadata["type"])
)
assert ctx.topic_id is not None
# Publish the code execution task.
await self.publish_message(
CodeExecutionTask(content=response.content, session_id=session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@message_handler
@@ -124,11 +123,10 @@ Reply "TERMINATE" in the end when everything is done."""
if "TERMINATE" in response.content:
# If the task is completed, publish a message with the completion content.
assert ctx.topic_id is not None
await self.publish_message(
TaskCompletion(content=response.content),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
print("--------------------")
print("Task completed:")
@@ -136,11 +134,10 @@ Reply "TERMINATE" in the end when everything is done."""
return
# Publish the code execution task.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTask(content=response.content, session_id=message.session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@@ -157,13 +154,12 @@ class Executor(RoutedAgent):
code_blocks = self._extract_code_blocks(message.content)
if not code_blocks:
# If no code block is found, publish a message with an error.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTaskResult(
output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id
),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
return
# Execute code blocks.
@@ -171,11 +167,10 @@ class Executor(RoutedAgent):
code_blocks=code_blocks, cancellation_token=ctx.cancellation_token
)
# Publish the code execution result.
assert ctx.topic_id is not None
await self.publish_message(
CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id),
cancellation_token=ctx.cancellation_token,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]:

View File

@@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import Dict, List, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
@@ -112,14 +112,13 @@ Please review the code and provide feedback.
review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()])
approved = review["approval"].lower().strip() == "approve"
# Publish the review result.
assert ctx.topic_id is not None
await self.publish_message(
CodeReviewResult(
review=review_text,
approved=approved,
session_id=message.session_id,
),
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@@ -183,10 +182,9 @@ Code: <Your code>
# Store the code review task in the session memory.
self._session_memory[session_id].append(code_review_task)
# Publish a code review task.
assert ctx.topic_id is not None
await self.publish_message(
code_review_task,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@message_handler
@@ -201,14 +199,13 @@ Code: <Your code>
# Check if the code is approved.
if message.approved:
# Publish the code writing result.
assert ctx.topic_id is not None
await self.publish_message(
CodeWritingResult(
code=review_request.code,
task=review_request.code_writing_task,
review=message.review,
),
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
print("Code Writing Result:")
print("-" * 80)
@@ -247,10 +244,9 @@ Code: <Your code>
# Store the code review task in the session memory.
self._session_memory[message.session_id].append(code_review_task)
# Publish a new code review task.
assert ctx.topic_id is not None
await self.publish_message(
code_review_task,
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
def _extract_code_block(self, markdown_text: str) -> Union[str, None]:

View File

@@ -18,7 +18,7 @@ from dataclasses import dataclass
from typing import List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
@@ -69,8 +69,7 @@ class RoundRobinGroupChatManager(RoutedAgent):
self._round_count += 1
if self._round_count > self._num_rounds * len(self._participants):
# End the conversation after the specified number of rounds.
assert ctx.topic_id is not None
await self.publish_message(Termination(), ctx.topic_id)
await self.publish_message(Termination(), DefaultTopicId())
return
# Send a request to speak message to the selected speaker.
await self.send_message(RequestToSpeak(), speaker)
@@ -107,8 +106,7 @@ class GroupChatParticipant(RoutedAgent):
assert isinstance(response.content, str)
speech = Message(content=response.content, source=self.metadata["type"])
self._memory.append(speech)
assert ctx.topic_id is not None
await self.publish_message(speech, topic_id=ctx.topic_id)
await self.publish_message(speech, topic_id=DefaultTopicId())
async def main() -> None:

View File

@@ -15,7 +15,7 @@ from dataclasses import dataclass
from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
from agnext.core import MessageContext
@@ -68,8 +68,7 @@ class ReferenceAgent(RoutedAgent):
response = await self._model_client.create(self._system_messages + [task_message])
assert isinstance(response.content, str)
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
await self.publish_message(task_result, topic_id=DefaultTopicId())
class AggregatorAgent(RoutedAgent):
@@ -93,8 +92,7 @@ class AggregatorAgent(RoutedAgent):
"""Handle a task message. This method publishes the task to the reference agents."""
session_id = str(uuid.uuid4())
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
assert ctx.topic_id is not None
await self.publish_message(ref_task, topic_id=ctx.topic_id)
await self.publish_message(ref_task, topic_id=DefaultTopicId())
@message_handler
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
@@ -108,8 +106,7 @@ class AggregatorAgent(RoutedAgent):
)
assert isinstance(response.content, str)
task_result = AggregatorTaskResult(result=response.content)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
await self.publish_message(task_result, topic_id=DefaultTopicId())
self._session_results.pop(message.session_id)
print(f"Aggregator result: {response.content}")

View File

@@ -40,7 +40,7 @@ from dataclasses import dataclass
from typing import Dict, List, Tuple
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import RoutedAgent, message_handler
from agnext.components import DefaultTopicId, RoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
@@ -165,11 +165,10 @@ class MathSolver(RoutedAgent):
answer = match.group(1)
# Increment the counter.
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
assert ctx.topic_id is not None
if self._counters[message.session_id] == self._max_round:
# If the counter reaches the maximum round, publishes a final response.
await self.publish_message(
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=DefaultTopicId()
)
else:
# Publish intermediate response.
@@ -181,7 +180,7 @@ class MathSolver(RoutedAgent):
session_id=message.session_id,
round=self._counters[message.session_id],
),
topic_id=ctx.topic_id,
topic_id=DefaultTopicId(),
)
@@ -199,9 +198,8 @@ class MathAggregator(RoutedAgent):
"in the form of {{answer}}, at the end of your response."
)
session_id = str(uuid.uuid4())
assert ctx.topic_id is not None
await self.publish_message(
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=DefaultTopicId()
)
@message_handler
@@ -212,8 +210,7 @@ class MathAggregator(RoutedAgent):
answers = [resp.answer for resp in self._responses[message.session_id]]
majority_answer = max(set(answers), key=answers.count)
# Publish the aggregated response.
assert ctx.topic_id is not None
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
await self.publish_message(Answer(content=majority_answer), topic_id=DefaultTopicId())
# Clear the responses.
self._responses.pop(message.session_id)
print(f"Aggregated answer: {majority_answer}")