mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Initial impl of topics and subscriptions (#350)
* initial impl of topics and subscriptions * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * add topic in context * migrate * migrate code for topics * migrate team one * edit notebooks * formatting * fix imports * Build proto * Fix circular import --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
@@ -41,6 +41,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@@ -48,6 +49,7 @@ from agnext.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.core import TopicId
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
@@ -163,9 +165,12 @@ class MathSolver(TypeRoutedAgent):
|
||||
answer = match.group(1)
|
||||
# Increment the counter.
|
||||
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
|
||||
assert ctx.topic_id is not None
|
||||
if self._counters[message.session_id] == self._max_round:
|
||||
# If the counter reaches the maximum round, publishes a final response.
|
||||
await self.publish_message(FinalSolverResponse(answer=answer, session_id=message.session_id))
|
||||
await self.publish_message(
|
||||
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
|
||||
)
|
||||
else:
|
||||
# Publish intermediate response.
|
||||
await self.publish_message(
|
||||
@@ -175,7 +180,8 @@ class MathSolver(TypeRoutedAgent):
|
||||
answer=answer,
|
||||
session_id=message.session_id,
|
||||
round=self._counters[message.session_id],
|
||||
)
|
||||
),
|
||||
topic_id=ctx.topic_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -193,7 +199,10 @@ class MathAggregator(TypeRoutedAgent):
|
||||
"in the form of {{answer}}, at the end of your response."
|
||||
)
|
||||
session_id = str(uuid.uuid4())
|
||||
await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(
|
||||
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
|
||||
)
|
||||
|
||||
@message_handler
|
||||
async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None:
|
||||
@@ -203,7 +212,8 @@ class MathAggregator(TypeRoutedAgent):
|
||||
answers = [resp.answer for resp in self._responses[message.session_id]]
|
||||
majority_answer = max(set(answers), key=answers.count)
|
||||
# Publish the aggregated response.
|
||||
await self.publish_message(Answer(content=majority_answer))
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
|
||||
# Clear the responses.
|
||||
self._responses.pop(message.session_id)
|
||||
print(f"Aggregated answer: {majority_answer}")
|
||||
@@ -223,6 +233,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver1"))
|
||||
await runtime.register(
|
||||
"MathSolver2",
|
||||
lambda: MathSolver(
|
||||
@@ -231,6 +242,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver2"))
|
||||
await runtime.register(
|
||||
"MathSolver3",
|
||||
lambda: MathSolver(
|
||||
@@ -239,6 +251,7 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver3"))
|
||||
await runtime.register(
|
||||
"MathSolver4",
|
||||
lambda: MathSolver(
|
||||
@@ -247,13 +260,14 @@ async def main(question: str) -> None:
|
||||
max_round=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "MathSolver4"))
|
||||
# Register the aggregator agent.
|
||||
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
# Send a math problem to the aggregator agent.
|
||||
await runtime.publish_message(Question(content=question), namespace="default")
|
||||
await runtime.publish_message(Question(content=question), topic_id=TopicId("default", "default"))
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user