Initial impl of topics and subscriptions (#350)

* initial impl of topics and subscriptions

* Update python/src/agnext/core/_agent_runtime.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* add topic in context

* migrate

* migrate code for topics

* migrate team one

* edit notebooks

* formatting

* fix imports

* Build proto

* Fix circular import

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-20 14:41:24 -04:00
committed by GitHub
parent 4ba7e84721
commit e1a823fb6d
71 changed files with 685 additions and 495 deletions

View File

@@ -16,11 +16,13 @@ from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components._type_subscription import TypeSubscription
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
from agnext.core import MessageContext
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import TopicId
from common.utils import get_chat_completion_client_from_envs
@@ -66,7 +68,8 @@ class ReferenceAgent(TypeRoutedAgent):
response = await self._model_client.create(self._system_messages + [task_message])
assert isinstance(response.content, str)
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
await self.publish_message(task_result)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
class AggregatorAgent(TypeRoutedAgent):
@@ -90,7 +93,8 @@ class AggregatorAgent(TypeRoutedAgent):
"""Handle a task message. This method publishes the task to the reference agents."""
session_id = str(uuid.uuid4())
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
await self.publish_message(ref_task)
assert ctx.topic_id is not None
await self.publish_message(ref_task, topic_id=ctx.topic_id)
@message_handler
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
@@ -104,7 +108,8 @@ class AggregatorAgent(TypeRoutedAgent):
)
assert isinstance(response.content, str)
task_result = AggregatorTaskResult(result=response.content)
await self.publish_message(task_result)
assert ctx.topic_id is not None
await self.publish_message(task_result, topic_id=ctx.topic_id)
self._session_results.pop(message.session_id)
print(f"Aggregator result: {response.content}")
@@ -120,6 +125,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.1),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent1"))
await runtime.register(
"ReferenceAgent2",
lambda: ReferenceAgent(
@@ -128,6 +134,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.5),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent2"))
await runtime.register(
"ReferenceAgent3",
lambda: ReferenceAgent(
@@ -136,6 +143,7 @@ async def main() -> None:
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=1.0),
),
)
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent3"))
await runtime.register(
"AggregatorAgent",
lambda: AggregatorAgent(
@@ -149,8 +157,11 @@ async def main() -> None:
num_references=3,
),
)
await runtime.add_subscription(TypeSubscription("default", "AggregatorAgent"))
run_context = runtime.start()
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
await runtime.publish_message(
AggregatorTask(task="What are something fun to do in SF?"), topic_id=TopicId("default", "default")
)
# Keep processing messages.
await run_context.stop_when_idle()