Initial impl of topics and subscriptions (#350)

* initial impl of topics and subscriptions

* Update python/src/agnext/core/_agent_runtime.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* add topic in context

* migrate

* migrate code for topics

* migrate team one

* edit notebooks

* formatting

* fix imports

* Build proto

* Fix circular import

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-20 14:41:24 -04:00
committed by GitHub
parent 4ba7e84721
commit e1a823fb6d
71 changed files with 685 additions and 495 deletions

View File

@@ -41,6 +41,7 @@ from typing import Dict, List, Tuple
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import (
AssistantMessage,
ChatCompletionClient,
@@ -48,6 +49,7 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import TopicId
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -163,9 +165,12 @@ class MathSolver(TypeRoutedAgent):
answer = match.group(1)
# Increment the counter.
self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1
assert ctx.topic_id is not None
if self._counters[message.session_id] == self._max_round:
# If the counter reaches the maximum round, publishes a final response.
await self.publish_message(FinalSolverResponse(answer=answer, session_id=message.session_id))
await self.publish_message(
FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id
)
else:
# Publish intermediate response.
await self.publish_message(
@@ -175,7 +180,8 @@ class MathSolver(TypeRoutedAgent):
answer=answer,
session_id=message.session_id,
round=self._counters[message.session_id],
)
),
topic_id=ctx.topic_id,
)
@@ -193,7 +199,10 @@ class MathAggregator(TypeRoutedAgent):
"in the form of {{answer}}, at the end of your response."
)
session_id = str(uuid.uuid4())
await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content))
assert ctx.topic_id is not None
await self.publish_message(
SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id
)
@message_handler
async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None:
@@ -203,7 +212,8 @@ class MathAggregator(TypeRoutedAgent):
answers = [resp.answer for resp in self._responses[message.session_id]]
majority_answer = max(set(answers), key=answers.count)
# Publish the aggregated response.
await self.publish_message(Answer(content=majority_answer))
assert ctx.topic_id is not None
await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id)
# Clear the responses.
self._responses.pop(message.session_id)
print(f"Aggregated answer: {majority_answer}")
@@ -223,6 +233,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver1"))
await runtime.register(
"MathSolver2",
lambda: MathSolver(
@@ -231,6 +242,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver2"))
await runtime.register(
"MathSolver3",
lambda: MathSolver(
@@ -239,6 +251,7 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver3"))
await runtime.register(
"MathSolver4",
lambda: MathSolver(
@@ -247,13 +260,14 @@ async def main(question: str) -> None:
max_round=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "MathSolver4"))
# Register the aggregator agent.
await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4))
run_context = runtime.start()
# Send a math problem to the aggregator agent.
await runtime.publish_message(Question(content=question), namespace="default")
await runtime.publish_message(Question(content=question), topic_id=TopicId("default", "default"))
await run_context.stop_when_idle()