mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-16 15:15:43 -05:00
Initial impl of topics and subscriptions (#350)
* initial impl of topics and subscriptions * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * add topic in context * migrate * migrate code for topics * migrate team one * edit notebooks * formatting * fix imports * Build proto * Fix circular import --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
@@ -16,11 +16,13 @@ from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components._type_subscription import TypeSubscription
|
||||
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
|
||||
from agnext.core import MessageContext
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from agnext.core import TopicId
|
||||
from common.utils import get_chat_completion_client_from_envs
|
||||
|
||||
|
||||
@@ -66,7 +68,8 @@ class ReferenceAgent(TypeRoutedAgent):
|
||||
response = await self._model_client.create(self._system_messages + [task_message])
|
||||
assert isinstance(response.content, str)
|
||||
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
|
||||
|
||||
class AggregatorAgent(TypeRoutedAgent):
|
||||
@@ -90,7 +93,8 @@ class AggregatorAgent(TypeRoutedAgent):
|
||||
"""Handle a task message. This method publishes the task to the reference agents."""
|
||||
session_id = str(uuid.uuid4())
|
||||
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
|
||||
await self.publish_message(ref_task)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(ref_task, topic_id=ctx.topic_id)
|
||||
|
||||
@message_handler
|
||||
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
|
||||
@@ -104,7 +108,8 @@ class AggregatorAgent(TypeRoutedAgent):
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
task_result = AggregatorTaskResult(result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
assert ctx.topic_id is not None
|
||||
await self.publish_message(task_result, topic_id=ctx.topic_id)
|
||||
self._session_results.pop(message.session_id)
|
||||
print(f"Aggregator result: {response.content}")
|
||||
|
||||
@@ -120,6 +125,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.1),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent1"))
|
||||
await runtime.register(
|
||||
"ReferenceAgent2",
|
||||
lambda: ReferenceAgent(
|
||||
@@ -128,6 +134,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.5),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent2"))
|
||||
await runtime.register(
|
||||
"ReferenceAgent3",
|
||||
lambda: ReferenceAgent(
|
||||
@@ -136,6 +143,7 @@ async def main() -> None:
|
||||
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=1.0),
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent3"))
|
||||
await runtime.register(
|
||||
"AggregatorAgent",
|
||||
lambda: AggregatorAgent(
|
||||
@@ -149,8 +157,11 @@ async def main() -> None:
|
||||
num_references=3,
|
||||
),
|
||||
)
|
||||
await runtime.add_subscription(TypeSubscription("default", "AggregatorAgent"))
|
||||
run_context = runtime.start()
|
||||
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
|
||||
await runtime.publish_message(
|
||||
AggregatorTask(task="What are something fun to do in SF?"), topic_id=TopicId("default", "default")
|
||||
)
|
||||
|
||||
# Keep processing messages.
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
Reference in New Issue
Block a user