From e1a823fb6d331f6be2dd5ae46aa66dcc0ad262ba Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 20 Aug 2024 14:41:24 -0400 Subject: [PATCH] Initial impl of topics and subscriptions (#350) * initial impl of topics and subscriptions * Update python/src/agnext/core/_agent_runtime.py Co-authored-by: Eric Zhu * add topic in context * migrate * migrate code for topics * migrate team one * edit notebooks * formatting * fix imports * Build proto * Fix circular import --------- Co-authored-by: Eric Zhu --- protos/agent_worker.proto | 10 +- .../docs/src/cookbook/langgraph-agent.ipynb | 7 +- .../docs/src/cookbook/llamaindex-agent.ipynb | 7 +- .../src/cookbook/openai-assistant-agent.ipynb | 7 +- .../agent-and-agent-runtime.ipynb | 6 +- .../message-and-communication.ipynb | 37 +++-- .../src/getting-started/model-clients.ipynb | 6 +- .../multi-agent-design-patterns.ipynb | 16 ++- python/docs/src/getting-started/tools.ipynb | 9 +- python/samples/byoa/langgraph_agent.py | 5 +- python/samples/byoa/llamaindex_agent.py | 5 +- .../common/agents/_chat_completion_agent.py | 3 +- .../common/agents/_image_generation_agent.py | 6 +- .../samples/common/agents/_oai_assistant.py | 3 +- python/samples/common/agents/_user_proxy.py | 3 +- python/samples/core/inner_outer_direct.py | 7 +- python/samples/core/one_agent_direct.py | 4 +- python/samples/core/two_agents_pub_sub.py | 16 ++- python/samples/demos/assistant.py | 29 ++-- python/samples/demos/chat_room.py | 16 ++- python/samples/demos/chess_game.py | 19 ++- python/samples/demos/illustrator_critics.py | 18 ++- python/samples/demos/software_consultancy.py | 29 ++-- python/samples/demos/utils.py | 5 +- python/samples/marketing-agents/app.py | 6 +- python/samples/marketing-agents/auditor.py | 5 +- .../marketing-agents/graphic_designer.py | 5 +- python/samples/marketing-agents/test_usage.py | 8 +- python/samples/marketing-agents/worker.py | 2 +- python/samples/patterns/coder_executor.py | 19 ++- python/samples/patterns/coder_reviewer.py | 26 +++- python/samples/patterns/group_chat.py | 29 ++-- python/samples/patterns/mixture_of_agents.py | 19 ++- python/samples/patterns/multi_agent_debate.py | 24 +++- python/samples/tool-use/coding_direct.py | 10 +- .../tool-use/coding_direct_with_intercept.py | 10 +- python/samples/tool-use/coding_pub_sub.py | 20 ++- python/samples/tool-use/custom_tool_direct.py | 9 +- python/samples/worker/run_worker_pub_sub.py | 22 ++- python/samples/worker/run_worker_rpc.py | 12 +- .../_single_threaded_agent_runtime.py | 114 +++++++-------- .../src/agnext/application/_worker_runtime.py | 134 +++++++----------- .../application/protos/agent_worker_pb2.py | 16 +-- .../application/protos/agent_worker_pb2.pyi | 19 +-- .../src/agnext/components/_closure_agent.py | 1 - .../agnext/components/_type_routed_agent.py | 6 +- .../agnext/components/_type_subscription.py | 1 - python/src/agnext/core/__init__.py | 3 +- python/src/agnext/core/_agent_metadata.py | 3 +- python/src/agnext/core/_agent_runtime.py | 97 ++++--------- python/src/agnext/core/_base_agent.py | 16 +-- python/src/agnext/core/_subscription.py | 19 ++- python/src/agnext/core/_topic.py | 2 +- python/teams/team-one/examples/example.py | 11 +- .../teams/team-one/examples/example_coder.py | 12 +- .../team-one/examples/example_file_surfer.py | 8 +- .../team-one/examples/example_reflexagents.py | 16 ++- .../team-one/examples/example_userproxy.py | 8 +- .../team-one/examples/example_websurfer.py | 7 +- .../src/team_one/agents/base_worker.py | 5 +- .../src/team_one/agents/orchestrator.py | 17 ++- .../src/team_one/agents/reflex_agents.py | 5 +- .../headless_web_surfer/test_web_surfer.py | 24 ++-- python/tests/test_cancellation.py | 47 +++--- python/tests/test_closure_agent.py | 14 +- python/tests/test_intervention.py | 23 +-- python/tests/test_runtime.py | 24 ++-- python/tests/test_serialization.py | 2 +- python/tests/test_state.py | 18 +-- python/tests/test_tool_agent.py | 4 +- python/tests/test_utils/__init__.py | 5 +- 71 files changed, 685 insertions(+), 495 deletions(-) diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 21bc410fd..9dd8f6996 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -2,7 +2,6 @@ syntax = "proto3"; package agents; -// TODO: update message AgentId { string name = 1; string namespace = 2; @@ -25,10 +24,11 @@ message RpcResponse { } message Event { - string namespace = 1; - string type = 2; - string data = 3; - map metadata = 4; + string topic_type = 1; + string topic_source = 2; + string data_type = 3; + string data = 4; + map metadata = 5; } message RegisterAgentType { diff --git a/python/docs/src/cookbook/langgraph-agent.ipynb b/python/docs/src/cookbook/langgraph-agent.ipynb index 27280a57b..888a10494 100644 --- a/python/docs/src/cookbook/langgraph-agent.ipynb +++ b/python/docs/src/cookbook/langgraph-agent.ipynb @@ -45,7 +45,7 @@ "\n", "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import MessageContext\n", + "from agnext.core import AgentId, MessageContext\n", "from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n", "from langchain_core.messages import HumanMessage, SystemMessage\n", "from langchain_core.tools import tool # pyright: ignore\n", @@ -195,7 +195,7 @@ "outputs": [], "source": [ "runtime = SingleThreadedAgentRuntime()\n", - "agent = await runtime.register_and_get(\n", + "await runtime.register(\n", " \"langgraph_tool_use_agent\",\n", " lambda: LangGraphToolUseAgent(\n", " \"Tool use agent\",\n", @@ -214,7 +214,8 @@ " # ),\n", " [get_weather],\n", " ),\n", - ")" + ")\n", + "agent = AgentId(\"langgraph_tool_use_agent\", key=\"default\")" ] }, { diff --git a/python/docs/src/cookbook/llamaindex-agent.ipynb b/python/docs/src/cookbook/llamaindex-agent.ipynb index d14dc261b..ae8c18e67 100644 --- a/python/docs/src/cookbook/llamaindex-agent.ipynb +++ b/python/docs/src/cookbook/llamaindex-agent.ipynb @@ -44,7 +44,7 @@ "\n", "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import MessageContext\n", + "from agnext.core import AgentId, MessageContext\n", "from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n", "from llama_index.core import Settings\n", "from llama_index.core.agent import ReActAgent\n", @@ -221,7 +221,7 @@ "outputs": [], "source": [ "runtime = SingleThreadedAgentRuntime()\n", - "agent = await runtime.register_and_get(\n", + "await runtime.register(\n", " \"chat_agent\",\n", " lambda: LlamaIndexAgent(\n", " description=\"Llama Index Agent\",\n", @@ -233,7 +233,8 @@ " verbose=True,\n", " ),\n", " ),\n", - ")" + ")\n", + "agent = AgentId(\"chat_agent\", \"default\")" ] }, { diff --git a/python/docs/src/cookbook/openai-assistant-agent.ipynb b/python/docs/src/cookbook/openai-assistant-agent.ipynb index feadcec11..a36c7fd05 100644 --- a/python/docs/src/cookbook/openai-assistant-agent.ipynb +++ b/python/docs/src/cookbook/openai-assistant-agent.ipynb @@ -105,7 +105,7 @@ "\n", "import aiofiles\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import MessageContext\n", + "from agnext.core import AgentId, MessageContext\n", "from openai import AsyncAssistantEventHandler, AsyncClient\n", "from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch\n", "\n", @@ -390,7 +390,7 @@ "from agnext.application import SingleThreadedAgentRuntime\n", "\n", "runtime = SingleThreadedAgentRuntime()\n", - "agent = await runtime.register_and_get(\n", + "await runtime.register(\n", " \"assistant\",\n", " lambda: OpenAIAssistantAgent(\n", " description=\"OpenAI Assistant Agent\",\n", @@ -399,7 +399,8 @@ " thread_id=thread.id,\n", " assistant_event_handler_factory=lambda: EventHandler(),\n", " ),\n", - ")" + ")\n", + "agent = AgentId(\"assistant\", \"default\")" ] }, { diff --git a/python/docs/src/getting-started/agent-and-agent-runtime.ipynb b/python/docs/src/getting-started/agent-and-agent-runtime.ipynb index 4717d4804..165b99690 100644 --- a/python/docs/src/getting-started/agent-and-agent-runtime.ipynb +++ b/python/docs/src/getting-started/agent-and-agent-runtime.ipynb @@ -57,7 +57,7 @@ "source": [ "from dataclasses import dataclass\n", "\n", - "from agnext.core import BaseAgent, MessageContext\n", + "from agnext.core import AgentId, BaseAgent, MessageContext\n", "\n", "\n", "@dataclass\n", @@ -67,7 +67,7 @@ "\n", "class MyAgent(BaseAgent):\n", " def __init__(self) -> None:\n", - " super().__init__(\"MyAgent\", subscriptions=[\"MyMessage\"])\n", + " super().__init__(\"MyAgent\")\n", "\n", " async def on_message(self, message: MyMessage, ctx: MessageContext) -> None:\n", " print(f\"Received message: {message.content}\")" @@ -133,7 +133,7 @@ } ], "source": [ - "agent_id = await runtime.get(\"my_agent\")\n", + "agent_id = AgentId(\"my_agent\", \"default\")\n", "run_context = runtime.start() # Start processing messages in the background.\n", "await runtime.send_message(MyMessage(content=\"Hello, World!\"), agent_id)\n", "await run_context.stop() # Stop processing messages in the background." diff --git a/python/docs/src/getting-started/message-and-communication.ipynb b/python/docs/src/getting-started/message-and-communication.ipynb index a9ae906cd..61e8bcd28 100644 --- a/python/docs/src/getting-started/message-and-communication.ipynb +++ b/python/docs/src/getting-started/message-and-communication.ipynb @@ -83,7 +83,7 @@ "source": [ "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import MessageContext\n", + "from agnext.core import AgentId, MessageContext\n", "\n", "\n", "class MyAgent(TypeRoutedAgent):\n", @@ -110,7 +110,8 @@ "outputs": [], "source": [ "runtime = SingleThreadedAgentRuntime()\n", - "agent = await runtime.register_and_get(\"my_agent\", lambda: MyAgent(\"My Agent\"))" + "await runtime.register(\"my_agent\", lambda: MyAgent(\"My Agent\"))\n", + "agent = AgentId(\"my_agent\", \"default\")" ] }, { @@ -185,7 +186,7 @@ "\n", "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import AgentId, MessageContext\n", + "from agnext.core import MessageContext\n", "\n", "\n", "@dataclass\n", @@ -200,9 +201,9 @@ "\n", "\n", "class OuterAgent(TypeRoutedAgent):\n", - " def __init__(self, description: str, inner_agent_id: AgentId):\n", + " def __init__(self, description: str, inner_agent_type: str):\n", " super().__init__(description)\n", - " self.inner_agent_id = inner_agent_id\n", + " self.inner_agent_id = AgentId(inner_agent_type, self.id.key)\n", "\n", " @message_handler\n", " async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n", @@ -238,9 +239,10 @@ ], "source": [ "runtime = SingleThreadedAgentRuntime()\n", - "inner = await runtime.register_and_get(\"inner_agent\", lambda: InnerAgent(\"InnerAgent\"))\n", - "outer = await runtime.register_and_get(\"outer_agent\", lambda: OuterAgent(\"OuterAgent\", inner))\n", + "await runtime.register(\"inner_agent\", lambda: InnerAgent(\"InnerAgent\"))\n", + "await runtime.register(\"outer_agent\", lambda: OuterAgent(\"OuterAgent\", \"InnerAgent\"))\n", "run_context = runtime.start()\n", + "outer = AgentId(\"outer_agent\", \"default\")\n", "await runtime.send_message(Message(content=\"Hello, World!\"), outer)\n", "await run_context.stop_when_idle()" ] @@ -294,14 +296,17 @@ "source": [ "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import MessageContext\n", + "from agnext.core import MessageContext, TopicId\n", "\n", "\n", "class BroadcastingAgent(TypeRoutedAgent):\n", " @message_handler\n", " async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n", " # Publish a message to all agents in the same namespace.\n", - " await self.publish_message(Message(f\"Publishing a message: {message.content}!\"))\n", + " assert ctx.topic_id is not None\n", + " await self.publish_message(\n", + " Message(f\"Publishing a message: {message.content}!\"), topic_id=TopicId(\"deafult\", self.id.key)\n", + " )\n", "\n", "\n", "class ReceivingAgent(TypeRoutedAgent):\n", @@ -332,11 +337,15 @@ } ], "source": [ + "from agnext.components import TypeSubscription\n", + "\n", "runtime = SingleThreadedAgentRuntime()\n", - "broadcaster = await runtime.register_and_get(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n", + "await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n", "await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n", + "await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n", + "await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n", "run_context = runtime.start()\n", - "await runtime.send_message(Message(\"Hello, World!\"), broadcaster)\n", + "await runtime.send_message(Message(\"Hello, World!\"), AgentId(\"broadcasting_agent\", \"default\"))\n", "await run_context.stop_when_idle()" ] }, @@ -367,10 +376,12 @@ "# Replace send_message with publish_message in the above example.\n", "\n", "runtime = SingleThreadedAgentRuntime()\n", - "broadcaster = await runtime.register_and_get(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n", + "await runtime.register(\"broadcasting_agent\", lambda: BroadcastingAgent(\"Broadcasting Agent\"))\n", "await runtime.register(\"receiving_agent\", lambda: ReceivingAgent(\"Receiving Agent\"))\n", + "await runtime.add_subscription(TypeSubscription(\"default\", \"broadcasting_agent\"))\n", + "await runtime.add_subscription(TypeSubscription(\"default\", \"receiving_agent\"))\n", "run_context = runtime.start()\n", - "await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), namespace=\"default\")\n", + "await runtime.publish_message(Message(\"Hello, World! From the runtime!\"), topic_id=TopicId(\"default\", \"default\"))\n", "await run_context.stop_when_idle()" ] }, diff --git a/python/docs/src/getting-started/model-clients.ipynb b/python/docs/src/getting-started/model-clients.ipynb index 048821044..41dceb113 100644 --- a/python/docs/src/getting-started/model-clients.ipynb +++ b/python/docs/src/getting-started/model-clients.ipynb @@ -318,8 +318,10 @@ ], "source": [ "# Create the runtime and register the agent.\n", + "from agnext.core import AgentId\n", + "\n", "runtime = SingleThreadedAgentRuntime()\n", - "agent = await runtime.register_and_get(\n", + "await runtime.register(\n", " \"simple-agent\",\n", " lambda: SimpleAgent(\n", " OpenAIChatCompletionClient(\n", @@ -332,7 +334,7 @@ "run_context = runtime.start()\n", "# Send a message to the agent and get the response.\n", "message = Message(\"Hello, what are some fun things to do in Seattle?\")\n", - "response = await runtime.send_message(message, agent)\n", + "response = await runtime.send_message(message, AgentId(\"simple-agent\", \"default\"))\n", "print(response.content)\n", "# Stop the runtime processing messages.\n", "await run_context.stop()" diff --git a/python/docs/src/getting-started/multi-agent-design-patterns.ipynb b/python/docs/src/getting-started/multi-agent-design-patterns.ipynb index 8cd4a39eb..720653f5f 100644 --- a/python/docs/src/getting-started/multi-agent-design-patterns.ipynb +++ b/python/docs/src/getting-started/multi-agent-design-patterns.ipynb @@ -131,7 +131,7 @@ " SystemMessage,\n", " UserMessage,\n", ")\n", - "from agnext.core import MessageContext" + "from agnext.core import MessageContext, TopicId" ] }, { @@ -201,7 +201,7 @@ " # Store the code review task in the session memory.\n", " self._session_memory[session_id].append(code_review_task)\n", " # Publish a code review task.\n", - " await self.publish_message(code_review_task)\n", + " await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n", "\n", " @message_handler\n", " async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:\n", @@ -220,7 +220,8 @@ " code=review_request.code,\n", " task=review_request.code_writing_task,\n", " review=message.review,\n", - " )\n", + " ),\n", + " topic_id=TopicId(\"default\", self.id.key),\n", " )\n", " print(\"Code Writing Result:\")\n", " print(\"-\" * 80)\n", @@ -259,7 +260,7 @@ " # Store the code review task in the session memory.\n", " self._session_memory[message.session_id].append(code_review_task)\n", " # Publish a new code review task.\n", - " await self.publish_message(code_review_task)\n", + " await self.publish_message(code_review_task, topic_id=TopicId(\"default\", self.id.key))\n", "\n", " def _extract_code_block(self, markdown_text: str) -> Union[str, None]:\n", " pattern = r\"```(\\w+)\\n(.*?)\\n```\"\n", @@ -360,7 +361,7 @@ " # Store the review result in the session memory.\n", " self._session_memory[message.session_id].append(result)\n", " # Publish the review result.\n", - " await self.publish_message(result)" + " await self.publish_message(result, topic_id=TopicId(\"default\", self.id.key))" ] }, { @@ -494,6 +495,7 @@ ], "source": [ "from agnext.application import SingleThreadedAgentRuntime\n", + "from agnext.components._type_subscription import TypeSubscription\n", "from agnext.components.models import OpenAIChatCompletionClient\n", "\n", "runtime = SingleThreadedAgentRuntime()\n", @@ -501,14 +503,16 @@ " \"ReviewerAgent\",\n", " lambda: ReviewerAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")),\n", ")\n", + "await runtime.add_subscription(TypeSubscription(\"default\", \"CoderAgent\"))\n", "await runtime.register(\n", " \"CoderAgent\",\n", " lambda: CoderAgent(model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\")),\n", ")\n", + "await runtime.add_subscription(TypeSubscription(\"default\", \"ReviewerAgent\"))\n", "run_context = runtime.start()\n", "await runtime.publish_message(\n", " message=CodeWritingTask(task=\"Write a function to find the sum of all even numbers in a list.\"),\n", - " namespace=\"default\",\n", + " topic_id=TopicId(\"default\", \"default\"),\n", ")\n", "\n", "# Keep processing messages until idle.\n", diff --git a/python/docs/src/getting-started/tools.ipynb b/python/docs/src/getting-started/tools.ipynb index 7a3ddc2bb..373a7b9f0 100644 --- a/python/docs/src/getting-started/tools.ipynb +++ b/python/docs/src/getting-started/tools.ipynb @@ -148,7 +148,7 @@ ")\n", "from agnext.components.tool_agent import ToolAgent, ToolException\n", "from agnext.components.tools import FunctionTool, Tool, ToolSchema\n", - "from agnext.core import AgentId, MessageContext\n", + "from agnext.core import AgentId, AgentInstantiationContext, MessageContext\n", "\n", "\n", "@dataclass\n", @@ -239,19 +239,19 @@ "# Create the tools.\n", "tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n", "# Register the agents.\n", - "tool_executor_agent = await runtime.register_and_get(\n", + "await runtime.register(\n", " \"tool-executor-agent\",\n", " lambda: ToolAgent(\n", " description=\"Tool Executor Agent\",\n", " tools=tools,\n", " ),\n", ")\n", - "tool_use_agent = await runtime.register_and_get(\n", + "await runtime.register(\n", " \"tool-use-agent\",\n", " lambda: ToolUseAgent(\n", " OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n", " tool_schema=[tool.schema for tool in tools],\n", - " tool_agent=tool_executor_agent,\n", + " tool_agent=AgentId(\"tool-executor-agent\", AgentInstantiationContext.current_agent_id().key),\n", " ),\n", ")" ] @@ -282,6 +282,7 @@ "# Start processing messages.\n", "run_context = runtime.start()\n", "# Send a direct message to the tool agent.\n", + "tool_use_agent = AgentId(\"tool-use-agent\", \"default\")\n", "response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n", "print(response.content)\n", "# Stop processing messages.\n", diff --git a/python/samples/byoa/langgraph_agent.py b/python/samples/byoa/langgraph_agent.py index 41c3ee43c..dcf17f921 100644 --- a/python/samples/byoa/langgraph_agent.py +++ b/python/samples/byoa/langgraph_agent.py @@ -10,7 +10,7 @@ from typing import Any, Callable, List, Literal from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import MessageContext +from agnext.core import AgentId, MessageContext from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool # pyright: ignore from langchain_openai import ChatOpenAI @@ -110,7 +110,7 @@ async def main() -> None: # Create runtime. runtime = SingleThreadedAgentRuntime() # Register the agent. - agent = await runtime.register_and_get( + await runtime.register( "langgraph_tool_use_agent", lambda: LangGraphToolUseAgent( "Tool use agent", @@ -118,6 +118,7 @@ async def main() -> None: [get_weather], ), ) + agent = AgentId("langgraph_tool_use_agent", key="default") # Start the runtime. run_context = runtime.start() # Send a message to the agent and get a response. diff --git a/python/samples/byoa/llamaindex_agent.py b/python/samples/byoa/llamaindex_agent.py index 72d06e7d5..16ae42121 100644 --- a/python/samples/byoa/llamaindex_agent.py +++ b/python/samples/byoa/llamaindex_agent.py @@ -9,7 +9,7 @@ from typing import List, Optional from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import MessageContext +from agnext.core import AgentId, MessageContext from llama_index.core import Settings from llama_index.core.agent import ReActAgent from llama_index.core.agent.runner.base import AgentRunner @@ -119,10 +119,11 @@ async def main() -> None: tools=[wikipedia_tool], llm=llm, max_iterations=8, memory=memory, verbose=True ) - agent = await runtime.register_and_get( + await runtime.register( "chat_agent", lambda: LlamaIndexAgent("Chat agent", llama_index_agent=llama_index_agent), ) + agent = AgentId("chat_agent", key="default") run_context = runtime.start() diff --git a/python/samples/common/agents/_chat_completion_agent.py b/python/samples/common/agents/_chat_completion_agent.py index bd48aae5b..3109e8030 100644 --- a/python/samples/common/agents/_chat_completion_agent.py +++ b/python/samples/common/agents/_chat_completion_agent.py @@ -110,8 +110,9 @@ class ChatCompletionAgent(TypeRoutedAgent): # Generate a response. response = await self._generate_response(message.response_format, ctx) + assert ctx.topic_id is not None # Publish the response. - await self.publish_message(response) + await self.publish_message(response, topic_id=ctx.topic_id) @message_handler() async def on_tool_call_message( diff --git a/python/samples/common/agents/_image_generation_agent.py b/python/samples/common/agents/_image_generation_agent.py index 31650e1c7..1aad9d9e4 100644 --- a/python/samples/common/agents/_image_generation_agent.py +++ b/python/samples/common/agents/_image_generation_agent.py @@ -7,8 +7,7 @@ from agnext.components import ( message_handler, ) from agnext.components.memory import ChatMemory -from agnext.core import MessageContext -from agnext.core._cancellation_token import CancellationToken +from agnext.core import CancellationToken, MessageContext from ..types import ( Message, @@ -58,7 +57,8 @@ class ImageGenerationAgent(TypeRoutedAgent): image is published as a MultiModalMessage.""" response = await self._generate_response(ctx.cancellation_token) - await self.publish_message(response) + assert ctx.topic_id is not None + await self.publish_message(response, topic_id=ctx.topic_id) async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage: messages = await self._memory.get_messages() diff --git a/python/samples/common/agents/_oai_assistant.py b/python/samples/common/agents/_oai_assistant.py index 75a46d077..954c957f8 100644 --- a/python/samples/common/agents/_oai_assistant.py +++ b/python/samples/common/agents/_oai_assistant.py @@ -80,7 +80,8 @@ class OpenAIAssistantAgent(TypeRoutedAgent): async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: """Handle a publish now message. This method generates a response and publishes it.""" response = await self._generate_response(message.response_format, ctx.cancellation_token) - await self.publish_message(response) + assert ctx.topic_id is not None + await self.publish_message(response, ctx.topic_id) async def _generate_response( self, diff --git a/python/samples/common/agents/_user_proxy.py b/python/samples/common/agents/_user_proxy.py index 723490533..c04c59378 100644 --- a/python/samples/common/agents/_user_proxy.py +++ b/python/samples/common/agents/_user_proxy.py @@ -23,7 +23,8 @@ class UserProxyAgent(TypeRoutedAgent): async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: """Handle a publish now message. This method prompts the user for input, then publishes it.""" user_input = await self.get_user_input(self._user_input_prompt) - await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"])) + assert ctx.topic_id is not None + await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id) async def get_user_input(self, prompt: str) -> str: """Get user input from the console. Override this method to customize how user input is retrieved.""" diff --git a/python/samples/core/inner_outer_direct.py b/python/samples/core/inner_outer_direct.py index cc4fe3c53..d193773bc 100644 --- a/python/samples/core/inner_outer_direct.py +++ b/python/samples/core/inner_outer_direct.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentId, MessageContext +from agnext.core import AgentId, AgentInstantiationContext, MessageContext @dataclass @@ -45,8 +45,9 @@ class Outer(TypeRoutedAgent): async def main() -> None: runtime = SingleThreadedAgentRuntime() - inner = await runtime.register_and_get("inner", Inner) - outer = await runtime.register_and_get("outer", lambda: Outer(inner)) + await runtime.register("inner", Inner) + await runtime.register("outer", lambda: Outer(AgentId("outer", AgentInstantiationContext.current_agent_id().key))) + outer = AgentId("outer", "default") run_context = runtime.start() diff --git a/python/samples/core/one_agent_direct.py b/python/samples/core/one_agent_direct.py index 6319610ba..4775716f5 100644 --- a/python/samples/core/one_agent_direct.py +++ b/python/samples/core/one_agent_direct.py @@ -17,6 +17,7 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) +from agnext.core import AgentId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -45,10 +46,11 @@ class ChatCompletionAgent(TypeRoutedAgent): async def main() -> None: runtime = SingleThreadedAgentRuntime() - agent = await runtime.register_and_get( + await runtime.register( "chat_agent", lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-4o-mini")), ) + agent = AgentId("chat_agent", "default") run_context = runtime.start() diff --git a/python/samples/core/two_agents_pub_sub.py b/python/samples/core/two_agents_pub_sub.py index 8532781ad..30ce6ed1b 100644 --- a/python/samples/core/two_agents_pub_sub.py +++ b/python/samples/core/two_agents_pub_sub.py @@ -18,6 +18,7 @@ from typing import List from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler +from agnext.components._type_subscription import TypeSubscription from agnext.components.models import ( AssistantMessage, ChatCompletionClient, @@ -25,6 +26,7 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) +from agnext.core import AgentId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -69,7 +71,11 @@ class ChatCompletionAgent(TypeRoutedAgent): llm_messages.append(UserMessage(content=m.content, source=m.source)) response = await self._model_client.create(self._system_messages + llm_messages) assert isinstance(response.content, str) - await self.publish_message(Message(content=response.content, source=self.metadata["type"])) + + if ctx.topic_id is not None: + await self.publish_message( + Message(content=response.content, source=self.metadata["type"]), topic_id=ctx.topic_id + ) async def main() -> None: @@ -77,7 +83,7 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register the agents. - jack = await runtime.register_and_get( + await runtime.register( "Jack", lambda: ChatCompletionAgent( description="Jack a comedian", @@ -88,7 +94,8 @@ async def main() -> None: termination_word="TERMINATE", ), ) - await runtime.register_and_get( + await runtime.add_subscription(TypeSubscription("default", "Jack")) + await runtime.register( "Cathy", lambda: ChatCompletionAgent( description="Cathy a poet", @@ -99,12 +106,13 @@ async def main() -> None: termination_word="TERMINATE", ), ) + await runtime.add_subscription(TypeSubscription("default", "Cathy")) run_context = runtime.start() # Send a message to Jack to start the conversation. message = Message(content="Can you tell me something fun about SF?", source="User") - await runtime.send_message(message, jack) + await runtime.send_message(message, AgentId("jack", "default")) # Process messages. await run_context.stop_when_idle() diff --git a/python/samples/demos/assistant.py b/python/samples/demos/assistant.py index e231f213c..684543fe2 100644 --- a/python/samples/demos/assistant.py +++ b/python/samples/demos/assistant.py @@ -13,7 +13,7 @@ import aiofiles import openai from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentId, AgentRuntime, CancellationToken +from agnext.core import AgentId, AgentRuntime, MessageContext from openai import AsyncAssistantEventHandler from openai.types.beta.thread import ToolResources from openai.types.beta.threads import Message, Text, TextDelta @@ -22,6 +22,7 @@ from typing_extensions import override sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from agnext.core import AgentInstantiationContext from common.agents import OpenAIAssistantAgent from common.memory import BufferedChatMemory from common.patterns._group_chat_manager import GroupChatManager @@ -30,7 +31,7 @@ from common.types import PublishNow, TextMessage sep = "-" * 50 -class UserProxyAgent(TypeRoutedAgent): # type: ignore +class UserProxyAgent(TypeRoutedAgent): def __init__( # type: ignore self, client: openai.AsyncClient, # type: ignore @@ -47,7 +48,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore self._vector_store_id = vector_store_id @message_handler() # type: ignore - async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # type: ignore + async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None: # TODO: render image if message has image. # print(f"{message.source}: {message.content}") pass @@ -57,7 +58,7 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore return await loop.run_in_executor(None, input, prompt) @message_handler() # type: ignore - async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: # type: ignore + async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: while True: user_input = await self._get_user_input(f"\n{sep}\nYou: ") # Parse upload file command '[upload code_interpreter | file_search filename]'. @@ -108,7 +109,10 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore return else: # Publish user input and exit handler. - await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"])) + assert ctx.topic_id is not None + await self.publish_message( + TextMessage(content=user_input, source=self.metadata["type"]), topic_id=ctx.topic_id + ) return @@ -166,7 +170,7 @@ class EventHandler(AsyncAssistantEventHandler): print("\n".join(citations)) -async def assistant_chat(runtime: AgentRuntime) -> AgentId: +async def assistant_chat(runtime: AgentRuntime) -> str: oai_assistant = openai.beta.assistants.create( model="gpt-4-turbo", description="An AI assistant that helps with everyday tasks.", @@ -177,7 +181,7 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId: thread = openai.beta.threads.create( tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, ) - assistant = await runtime.register_and_get( + await runtime.register( "Assistant", lambda: OpenAIAssistantAgent( description="An AI assistant that helps with everyday tasks.", @@ -188,7 +192,7 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId: ), ) - user = await runtime.register_and_get( + await runtime.register( "User", lambda: UserProxyAgent( client=openai.AsyncClient(), @@ -203,10 +207,13 @@ async def assistant_chat(runtime: AgentRuntime) -> AgentId: lambda: GroupChatManager( description="A group chat manager.", memory=BufferedChatMemory(buffer_size=10), - participants=[assistant, user], + participants=[ + AgentId("Assistant", AgentInstantiationContext.current_agent_id().key), + AgentId("User", AgentInstantiationContext.current_agent_id().key), + ], ), ) - return user + return "User" async def main() -> None: @@ -229,7 +236,7 @@ Type "exit" to exit the chat. _run_context = runtime.start() print(usage) # Request the user to start the conversation. - await runtime.send_message(PublishNow(), user) + await runtime.send_message(PublishNow(), AgentId(user, "default")) # TODO: have a way to exit the loop. diff --git a/python/samples/demos/chat_room.py b/python/samples/demos/chat_room.py index e6ea5dba4..bf4ffa638 100644 --- a/python/samples/demos/chat_room.py +++ b/python/samples/demos/chat_room.py @@ -9,7 +9,7 @@ from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler from agnext.components.memory import ChatMemory from agnext.components.models import ChatCompletionClient, SystemMessage -from agnext.core import AgentInstantiationContext, AgentRuntime +from agnext.core import AgentId, AgentInstantiationContext, AgentProxy, AgentRuntime sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -76,7 +76,10 @@ Use the following JSON format to provide your thought on the latest message and # Publish the response if needed. if respond is True or str(respond).lower().strip() == "true": - await self.publish_message(TextMessage(source=self.metadata["type"], content=str(response))) + assert ctx.topic_id is not None + await self.publish_message( + TextMessage(source=self.metadata["type"], content=str(response)), topic_id=ctx.topic_id + ) class ChatRoomUserAgent(TextualUserAgent): @@ -96,7 +99,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: app=app, ), ) - alice = await runtime.register_and_get_proxy( + await runtime.register( "Alice", lambda: ChatRoomAgent( name=AgentInstantiationContext.current_agent_id().type, @@ -106,7 +109,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) - bob = await runtime.register_and_get_proxy( + alice = AgentProxy(AgentId("Alice", "default"), runtime) + await runtime.register( "Bob", lambda: ChatRoomAgent( name=AgentInstantiationContext.current_agent_id().type, @@ -116,7 +120,8 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) - charlie = await runtime.register_and_get_proxy( + bob = AgentProxy(AgentId("Bob", "default"), runtime) + await runtime.register( "Charlie", lambda: ChatRoomAgent( name=AgentInstantiationContext.current_agent_id().type, @@ -126,6 +131,7 @@ async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) + charlie = AgentProxy(AgentId("Charlie", "default"), runtime) app.welcoming_notice = f"""Welcome to the chat room demo with the following participants: 1. 👧 {alice.id.type}: {(await alice.metadata)['description']} 2. 👱🏼‍♂️ {bob.id.type}: {(await bob.metadata)['description']} diff --git a/python/samples/demos/chess_game.py b/python/samples/demos/chess_game.py index 45ea19926..4c98c1cac 100644 --- a/python/samples/demos/chess_game.py +++ b/python/samples/demos/chess_game.py @@ -10,14 +10,16 @@ import sys from typing import Annotated, Literal from agnext.application import SingleThreadedAgentRuntime +from agnext.components._type_subscription import TypeSubscription from agnext.components.models import SystemMessage from agnext.components.tools import FunctionTool -from agnext.core import AgentRuntime +from agnext.core import AgentInstantiationContext, AgentRuntime, TopicId from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move from chess import piece_name as get_piece_name sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from agnext.core import AgentId from common.agents._chat_completion_agent import ChatCompletionAgent from common.memory import BufferedChatMemory from common.patterns._group_chat_manager import GroupChatManager @@ -156,7 +158,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore ), ] - black = await runtime.register_and_get( + await runtime.register( "PlayerBlack", lambda: ChatCompletionAgent( description="Player playing black.", @@ -173,7 +175,8 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore tools=black_tools, ), ) - white = await runtime.register_and_get( + await runtime.add_subscription(TypeSubscription("default", "PlayerBlack")) + await runtime.register( "PlayerWhite", lambda: ChatCompletionAgent( description="Player playing white.", @@ -190,6 +193,7 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore tools=white_tools, ), ) + await runtime.add_subscription(TypeSubscription("default", "PlayerWhite")) # Create a group chat manager for the chess game to orchestrate a turn-based # conversation between the two agents. await runtime.register( @@ -197,7 +201,10 @@ async def chess_game(runtime: AgentRuntime) -> None: # type: ignore lambda: GroupChatManager( description="A chess game between two agents.", memory=BufferedChatMemory(buffer_size=10), - participants=[white, black], # white goes first + participants=[ + AgentId("PlayerWhite", AgentInstantiationContext.current_agent_id().key), + AgentId("PlayerBlack", AgentInstantiationContext.current_agent_id().key), + ], # white goes first ), ) @@ -207,7 +214,9 @@ async def main() -> None: await chess_game(runtime) run_context = runtime.start() # Publish an initial message to trigger the group chat manager to start orchestration. - await runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default") + await runtime.publish_message( + TextMessage(content="Game started.", source="System"), topic_id=TopicId("default", "default") + ) await run_context.stop_when_idle() diff --git a/python/samples/demos/illustrator_critics.py b/python/samples/demos/illustrator_critics.py index 4e5d9a574..89821fd0d 100644 --- a/python/samples/demos/illustrator_critics.py +++ b/python/samples/demos/illustrator_critics.py @@ -7,11 +7,12 @@ import sys import openai from agnext.application import SingleThreadedAgentRuntime from agnext.components.models import SystemMessage -from agnext.core import AgentRuntime +from agnext.core import AgentInstantiationContext, AgentRuntime sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import AgentId, AgentProxy from common.agents import ChatCompletionAgent, ImageGenerationAgent from common.memory import BufferedChatMemory from common.patterns._group_chat_manager import GroupChatManager @@ -27,7 +28,7 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non app=app, ), ) - descriptor = await runtime.register_and_get_proxy( + await runtime.register( "Descriptor", lambda: ChatCompletionAgent( description="An AI agent that provides a description of the image.", @@ -46,7 +47,8 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo", max_tokens=500), ), ) - illustrator = await runtime.register_and_get_proxy( + descriptor = AgentProxy(AgentId("Descriptor", "default"), runtime) + await runtime.register( "Illustrator", lambda: ImageGenerationAgent( description="An AI agent that generates images.", @@ -55,7 +57,8 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non memory=BufferedChatMemory(buffer_size=1), ), ) - critic = await runtime.register_and_get_proxy( + illustrator = AgentProxy(AgentId("Illustrator", "default"), runtime) + await runtime.register( "Critic", lambda: ChatCompletionAgent( description="An AI agent that provides feedback on images given user's requirements.", @@ -74,12 +77,17 @@ async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> Non model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) + critic = AgentProxy(AgentId("Critic", "default"), runtime) await runtime.register( "GroupChatManager", lambda: GroupChatManager( description="A chat manager that handles group chat.", memory=BufferedChatMemory(buffer_size=5), - participants=[illustrator.id, critic.id, descriptor.id], + participants=[ + AgentId("Illustrator", AgentInstantiationContext.current_agent_id().key), + AgentId("Descriptor", AgentInstantiationContext.current_agent_id().key), + AgentId("Critic", AgentInstantiationContext.current_agent_id().key), + ], termination_word="APPROVE", ), ) diff --git a/python/samples/demos/software_consultancy.py b/python/samples/demos/software_consultancy.py index b239fbca8..1b9aaeffb 100644 --- a/python/samples/demos/software_consultancy.py +++ b/python/samples/demos/software_consultancy.py @@ -19,7 +19,7 @@ import openai from agnext.application import SingleThreadedAgentRuntime from agnext.components.models import SystemMessage from agnext.components.tools import FunctionTool -from agnext.core import AgentRuntime +from agnext.core import AgentInstantiationContext, AgentRuntime from markdownify import markdownify # type: ignore from tqdm import tqdm from typing_extensions import Annotated @@ -27,6 +27,7 @@ from typing_extensions import Annotated sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import AgentId from common.agents import ChatCompletionAgent from common.memory import HeadAndTailChatMemory from common.patterns._group_chat_manager import GroupChatManager @@ -106,14 +107,14 @@ async def create_image( async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore - user_agent = await runtime.register_and_get( + await runtime.register( "Customer", lambda: TextualUserAgent( description="A customer looking for help.", app=app, ), ) - developer = await runtime.register_and_get( + await runtime.register( "Developer", lambda: ChatCompletionAgent( description="A Python software developer.", @@ -149,11 +150,11 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No FunctionTool(list_files, name="list_files", description="List files in a directory."), FunctionTool(browse_web, name="browse_web", description="Browse a web page."), ], - tool_approver=user_agent, + tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key), ), ) - product_manager = await runtime.register_and_get( + await runtime.register( "ProductManager", lambda: ChatCompletionAgent( description="A product manager. " @@ -179,10 +180,10 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No FunctionTool(list_files, name="list_files", description="List files in a directory."), FunctionTool(browse_web, name="browse_web", description="Browse a web page."), ], - tool_approver=user_agent, + tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key), ), ) - ux_designer = await runtime.register_and_get( + await runtime.register( "UserExperienceDesigner", lambda: ChatCompletionAgent( description="A user experience designer for creating user interfaces.", @@ -211,11 +212,11 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No ), FunctionTool(list_files, name="list_files", description="List files in a directory."), ], - tool_approver=user_agent, + tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key), ), ) - illustrator = await runtime.register_and_get( + await runtime.register( "Illustrator", lambda: ChatCompletionAgent( description="An illustrator for creating images.", @@ -237,7 +238,7 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No description="Create an image from a description.", ), ], - tool_approver=user_agent, + tool_approver=AgentId("Customer", AgentInstantiationContext.current_agent_id().key), ), ) await runtime.register( @@ -246,7 +247,13 @@ async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> No description="A group chat manager.", memory=HeadAndTailChatMemory(head_size=1, tail_size=10), model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), - participants=[developer, product_manager, ux_designer, illustrator, user_agent], + participants=[ + AgentId("Developer", AgentInstantiationContext.current_agent_id().key), + AgentId("ProductManager", AgentInstantiationContext.current_agent_id().key), + AgentId("UserExperienceDesigner", AgentInstantiationContext.current_agent_id().key), + AgentId("Illustrator", AgentInstantiationContext.current_agent_id().key), + AgentId("Customer", AgentInstantiationContext.current_agent_id().key), + ], ), ) art = r""" diff --git a/python/samples/demos/utils.py b/python/samples/demos/utils.py index cf3e05c15..baf303c29 100644 --- a/python/samples/demos/utils.py +++ b/python/samples/demos/utils.py @@ -13,6 +13,7 @@ from textual_imageview.viewer import ImageViewer sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from agnext.core import TopicId from common.types import ( MultiModalMessage, PublishNow, @@ -135,7 +136,9 @@ class TextualChatApp(App): # type: ignore chat_messages.query("#typing").remove() # Publish the user message to the runtime. await self._runtime.publish_message( - TextMessage(source=self._user_name, content=user_input), namespace="default" + # TODO fix hard coded topic_id + TextMessage(source=self._user_name, content=user_input), + topic_id=TopicId("default", "default"), ) async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore diff --git a/python/samples/marketing-agents/app.py b/python/samples/marketing-agents/app.py index 5794ede92..1070495d9 100644 --- a/python/samples/marketing-agents/app.py +++ b/python/samples/marketing-agents/app.py @@ -1,5 +1,6 @@ import os +from agnext.components._type_subscription import TypeSubscription from agnext.components.models import AzureOpenAIChatCompletionClient from agnext.core import AgentRuntime from auditor import AuditAgent @@ -28,7 +29,6 @@ async def build_app(runtime: AgentRuntime) -> None: ) await runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client)) + await runtime.add_subscription(TypeSubscription("default", "GraphicDesigner")) await runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client)) - - await runtime.get("GraphicDesigner") - await runtime.get("Auditor") + await runtime.add_subscription(TypeSubscription("default", "Auditor")) diff --git a/python/samples/marketing-agents/auditor.py b/python/samples/marketing-agents/auditor.py index 4ac299dfe..5d2366f05 100644 --- a/python/samples/marketing-agents/auditor.py +++ b/python/samples/marketing-agents/auditor.py @@ -30,4 +30,7 @@ class AuditAgent(TypeRoutedAgent): assert isinstance(completion.content, str) if "NOTFORME" in completion.content: return - await self.publish_message(AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content)) + assert ctx.topic_id is not None + await self.publish_message( + AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content), topic_id=ctx.topic_id + ) diff --git a/python/samples/marketing-agents/graphic_designer.py b/python/samples/marketing-agents/graphic_designer.py index b997a1f88..b7c3b879a 100644 --- a/python/samples/marketing-agents/graphic_designer.py +++ b/python/samples/marketing-agents/graphic_designer.py @@ -33,6 +33,9 @@ class GraphicDesignerAgent(TypeRoutedAgent): image_uri = response.data[0].url logger.info(f"Generated image for article. Got response: '{image_uri}'") - await self.publish_message(GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri)) + assert ctx.topic_id is not None + await self.publish_message( + GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri), topic_id=ctx.topic_id + ) except Exception as e: logger.error(f"Failed to generate image for article. Error: {e}") diff --git a/python/samples/marketing-agents/test_usage.py b/python/samples/marketing-agents/test_usage.py index 81b03e515..a74e7855e 100644 --- a/python/samples/marketing-agents/test_usage.py +++ b/python/samples/marketing-agents/test_usage.py @@ -3,7 +3,7 @@ import os from agnext.application import SingleThreadedAgentRuntime from agnext.components import Image, TypeRoutedAgent, message_handler -from agnext.core import MessageContext +from agnext.core import MessageContext, TopicId from app import build_app from dotenv import load_dotenv from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated @@ -34,13 +34,15 @@ async def main() -> None: ctx = runtime.start() + topic_id = TopicId("default", "default") + await runtime.publish_message( - AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), namespace="default" + AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), topic_id=topic_id ) await runtime.publish_message( ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"), - namespace="default", + topic_id=topic_id, ) await ctx.stop_when_idle() diff --git a/python/samples/marketing-agents/worker.py b/python/samples/marketing-agents/worker.py index 241448c0e..0f09a3b20 100644 --- a/python/samples/marketing-agents/worker.py +++ b/python/samples/marketing-agents/worker.py @@ -2,7 +2,7 @@ import asyncio import logging from agnext.application import WorkerAgentRuntime -from agnext.core._serialization import MESSAGE_TYPE_REGISTRY +from agnext.core import MESSAGE_TYPE_REGISTRY from app import build_app from dotenv import load_dotenv from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated diff --git a/python/samples/patterns/coder_executor.py b/python/samples/patterns/coder_executor.py index 09278e63d..f03379c75 100644 --- a/python/samples/patterns/coder_executor.py +++ b/python/samples/patterns/coder_executor.py @@ -22,6 +22,7 @@ from typing import Dict, List from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler +from agnext.components._type_subscription import TypeSubscription from agnext.components.code_executor import CodeBlock, CodeExecutor, LocalCommandLineCodeExecutor from agnext.components.models import ( AssistantMessage, @@ -30,6 +31,7 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) +from agnext.core import TopicId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -100,10 +102,12 @@ Reply "TERMINATE" in the end when everything is done.""" AssistantMessage(content=response.content, source=self.metadata["type"]) ) + assert ctx.topic_id is not None # Publish the code execution task. await self.publish_message( CodeExecutionTask(content=response.content, session_id=session_id), cancellation_token=ctx.cancellation_token, + topic_id=ctx.topic_id, ) @message_handler @@ -120,8 +124,11 @@ Reply "TERMINATE" in the end when everything is done.""" if "TERMINATE" in response.content: # If the task is completed, publish a message with the completion content. + assert ctx.topic_id is not None await self.publish_message( - TaskCompletion(content=response.content), cancellation_token=ctx.cancellation_token + TaskCompletion(content=response.content), + cancellation_token=ctx.cancellation_token, + topic_id=ctx.topic_id, ) print("--------------------") print("Task completed:") @@ -129,9 +136,11 @@ Reply "TERMINATE" in the end when everything is done.""" return # Publish the code execution task. + assert ctx.topic_id is not None await self.publish_message( CodeExecutionTask(content=response.content, session_id=message.session_id), cancellation_token=ctx.cancellation_token, + topic_id=ctx.topic_id, ) @@ -148,11 +157,13 @@ class Executor(TypeRoutedAgent): code_blocks = self._extract_code_blocks(message.content) if not code_blocks: # If no code block is found, publish a message with an error. + assert ctx.topic_id is not None await self.publish_message( CodeExecutionTaskResult( output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id ), cancellation_token=ctx.cancellation_token, + topic_id=ctx.topic_id, ) return # Execute code blocks. @@ -160,9 +171,11 @@ class Executor(TypeRoutedAgent): code_blocks=code_blocks, cancellation_token=ctx.cancellation_token ) # Publish the code execution result. + assert ctx.topic_id is not None await self.publish_message( CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id), cancellation_token=ctx.cancellation_token, + topic_id=ctx.topic_id, ) def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]: @@ -185,10 +198,12 @@ async def main(task: str, temp_dir: str) -> None: "coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo")) ) await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir))) + await runtime.add_subscription(TypeSubscription("default", "coder")) + await runtime.add_subscription(TypeSubscription("default", "executor")) run_context = runtime.start() # Publish the task message. - await runtime.publish_message(TaskMessage(content=task), namespace="default") + await runtime.publish_message(TaskMessage(content=task), topic_id=TopicId("default", "default")) await run_context.stop_when_idle() diff --git a/python/samples/patterns/coder_reviewer.py b/python/samples/patterns/coder_reviewer.py index ccd0e617a..1a2e0d5f8 100644 --- a/python/samples/patterns/coder_reviewer.py +++ b/python/samples/patterns/coder_reviewer.py @@ -22,6 +22,7 @@ from typing import Dict, List, Union from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler +from agnext.components._type_subscription import TypeSubscription from agnext.components.models import ( AssistantMessage, ChatCompletionClient, @@ -29,6 +30,7 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) +from agnext.core import TopicId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -110,12 +112,14 @@ Please review the code and provide feedback. review_text = "Code review:\n" + "\n".join([f"{k}: {v}" for k, v in review.items()]) approved = review["approval"].lower().strip() == "approve" # Publish the review result. + assert ctx.topic_id is not None await self.publish_message( CodeReviewResult( review=review_text, approved=approved, session_id=message.session_id, - ) + ), + topic_id=ctx.topic_id, ) @@ -179,7 +183,11 @@ Code: # Store the code review task in the session memory. self._session_memory[session_id].append(code_review_task) # Publish a code review task. - await self.publish_message(code_review_task) + assert ctx.topic_id is not None + await self.publish_message( + code_review_task, + topic_id=ctx.topic_id, + ) @message_handler async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None: @@ -193,12 +201,14 @@ Code: # Check if the code is approved. if message.approved: # Publish the code writing result. + assert ctx.topic_id is not None await self.publish_message( CodeWritingResult( code=review_request.code, task=review_request.code_writing_task, review=message.review, - ) + ), + topic_id=ctx.topic_id, ) print("Code Writing Result:") print("-" * 80) @@ -237,7 +247,11 @@ Code: # Store the code review task in the session memory. self._session_memory[message.session_id].append(code_review_task) # Publish a new code review task. - await self.publish_message(code_review_task) + assert ctx.topic_id is not None + await self.publish_message( + code_review_task, + topic_id=ctx.topic_id, + ) def _extract_code_block(self, markdown_text: str) -> Union[str, None]: pattern = r"```(\w+)\n(.*?)\n```" @@ -258,6 +272,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), ), ) + await runtime.add_subscription(TypeSubscription("default", "ReviewerAgent")) await runtime.register( "CoderAgent", lambda: CoderAgent( @@ -265,12 +280,13 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), ), ) + await runtime.add_subscription(TypeSubscription("default", "CoderAgent")) run_context = runtime.start() await runtime.publish_message( message=CodeWritingTask( task="Write a function to find the directory with the largest number of files using multi-processing." ), - namespace="default", + topic_id=TopicId("default", "default"), ) # Keep processing messages until idle. diff --git a/python/samples/patterns/group_chat.py b/python/samples/patterns/group_chat.py index d500554ca..e39db257f 100644 --- a/python/samples/patterns/group_chat.py +++ b/python/samples/patterns/group_chat.py @@ -26,7 +26,7 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import AgentId +from agnext.core import AgentId, AgentInstantiationContext, TopicId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -69,7 +69,8 @@ class RoundRobinGroupChatManager(TypeRoutedAgent): self._round_count += 1 if self._round_count > self._num_rounds * len(self._participants): # End the conversation after the specified number of rounds. - await self.publish_message(Termination()) + assert ctx.topic_id is not None + await self.publish_message(Termination(), ctx.topic_id) return # Send a request to speak message to the selected speaker. await self.send_message(RequestToSpeak(), speaker) @@ -104,9 +105,10 @@ class GroupChatParticipant(TypeRoutedAgent): llm_messages.append(UserMessage(content=m.content, source=m.source)) response = await self._model_client.create(self._system_messages + llm_messages) assert isinstance(response.content, str) - speach = Message(content=response.content, source=self.metadata["type"]) - self._memory.append(speach) - await self.publish_message(speach) + speech = Message(content=response.content, source=self.metadata["type"]) + self._memory.append(speech) + assert ctx.topic_id is not None + await self.publish_message(speech, topic_id=ctx.topic_id) async def main() -> None: @@ -114,7 +116,7 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register the participants. - agent1 = await runtime.register_and_get( + await runtime.register( "DataScientist", lambda: GroupChatParticipant( description="A data scientist", @@ -122,7 +124,8 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), ), ) - agent2 = await runtime.register_and_get( + + await runtime.register( "Engineer", lambda: GroupChatParticipant( description="An engineer", @@ -130,7 +133,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), ), ) - agent3 = await runtime.register_and_get( + await runtime.register( "Artist", lambda: GroupChatParticipant( description="An artist", @@ -144,7 +147,11 @@ async def main() -> None: "GroupChatManager", lambda: RoundRobinGroupChatManager( description="A group chat manager", - participants=[agent1, agent2, agent3], + participants=[ + AgentId("DataScientist", AgentInstantiationContext.current_agent_id().key), + AgentId("Engineer", AgentInstantiationContext.current_agent_id().key), + AgentId("Artist", AgentInstantiationContext.current_agent_id().key), + ], num_rounds=3, ), ) @@ -153,7 +160,9 @@ async def main() -> None: run_context = runtime.start() # Start the conversation. - await runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default") + await runtime.publish_message( + Message(content="Hello, everyone!", source="Moderator"), topic_id=TopicId("default", "default") + ) await run_context.stop_when_idle() diff --git a/python/samples/patterns/mixture_of_agents.py b/python/samples/patterns/mixture_of_agents.py index a99a1fb9b..58cb32091 100644 --- a/python/samples/patterns/mixture_of_agents.py +++ b/python/samples/patterns/mixture_of_agents.py @@ -16,11 +16,13 @@ from typing import Dict, List from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler +from agnext.components._type_subscription import TypeSubscription from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage from agnext.core import MessageContext sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import TopicId from common.utils import get_chat_completion_client_from_envs @@ -66,7 +68,8 @@ class ReferenceAgent(TypeRoutedAgent): response = await self._model_client.create(self._system_messages + [task_message]) assert isinstance(response.content, str) task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content) - await self.publish_message(task_result) + assert ctx.topic_id is not None + await self.publish_message(task_result, topic_id=ctx.topic_id) class AggregatorAgent(TypeRoutedAgent): @@ -90,7 +93,8 @@ class AggregatorAgent(TypeRoutedAgent): """Handle a task message. This method publishes the task to the reference agents.""" session_id = str(uuid.uuid4()) ref_task = ReferenceAgentTask(session_id=session_id, task=message.task) - await self.publish_message(ref_task) + assert ctx.topic_id is not None + await self.publish_message(ref_task, topic_id=ctx.topic_id) @message_handler async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None: @@ -104,7 +108,8 @@ class AggregatorAgent(TypeRoutedAgent): ) assert isinstance(response.content, str) task_result = AggregatorTaskResult(result=response.content) - await self.publish_message(task_result) + assert ctx.topic_id is not None + await self.publish_message(task_result, topic_id=ctx.topic_id) self._session_results.pop(message.session_id) print(f"Aggregator result: {response.content}") @@ -120,6 +125,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.1), ), ) + await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent1")) await runtime.register( "ReferenceAgent2", lambda: ReferenceAgent( @@ -128,6 +134,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=0.5), ), ) + await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent2")) await runtime.register( "ReferenceAgent3", lambda: ReferenceAgent( @@ -136,6 +143,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini", temperature=1.0), ), ) + await runtime.add_subscription(TypeSubscription("default", "ReferenceAgent3")) await runtime.register( "AggregatorAgent", lambda: AggregatorAgent( @@ -149,8 +157,11 @@ async def main() -> None: num_references=3, ), ) + await runtime.add_subscription(TypeSubscription("default", "AggregatorAgent")) run_context = runtime.start() - await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default") + await runtime.publish_message( + AggregatorTask(task="What are something fun to do in SF?"), topic_id=TopicId("default", "default") + ) # Keep processing messages. await run_context.stop_when_idle() diff --git a/python/samples/patterns/multi_agent_debate.py b/python/samples/patterns/multi_agent_debate.py index 743f9c6fe..225c52a87 100644 --- a/python/samples/patterns/multi_agent_debate.py +++ b/python/samples/patterns/multi_agent_debate.py @@ -41,6 +41,7 @@ from typing import Dict, List, Tuple from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler +from agnext.components._type_subscription import TypeSubscription from agnext.components.models import ( AssistantMessage, ChatCompletionClient, @@ -48,6 +49,7 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) +from agnext.core import TopicId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -163,9 +165,12 @@ class MathSolver(TypeRoutedAgent): answer = match.group(1) # Increment the counter. self._counters[message.session_id] = self._counters.get(message.session_id, 0) + 1 + assert ctx.topic_id is not None if self._counters[message.session_id] == self._max_round: # If the counter reaches the maximum round, publishes a final response. - await self.publish_message(FinalSolverResponse(answer=answer, session_id=message.session_id)) + await self.publish_message( + FinalSolverResponse(answer=answer, session_id=message.session_id), topic_id=ctx.topic_id + ) else: # Publish intermediate response. await self.publish_message( @@ -175,7 +180,8 @@ class MathSolver(TypeRoutedAgent): answer=answer, session_id=message.session_id, round=self._counters[message.session_id], - ) + ), + topic_id=ctx.topic_id, ) @@ -193,7 +199,10 @@ class MathAggregator(TypeRoutedAgent): "in the form of {{answer}}, at the end of your response." ) session_id = str(uuid.uuid4()) - await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content)) + assert ctx.topic_id is not None + await self.publish_message( + SolverRequest(content=prompt, session_id=session_id, question=message.content), topic_id=ctx.topic_id + ) @message_handler async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None: @@ -203,7 +212,8 @@ class MathAggregator(TypeRoutedAgent): answers = [resp.answer for resp in self._responses[message.session_id]] majority_answer = max(set(answers), key=answers.count) # Publish the aggregated response. - await self.publish_message(Answer(content=majority_answer)) + assert ctx.topic_id is not None + await self.publish_message(Answer(content=majority_answer), topic_id=ctx.topic_id) # Clear the responses. self._responses.pop(message.session_id) print(f"Aggregated answer: {majority_answer}") @@ -223,6 +233,7 @@ async def main(question: str) -> None: max_round=3, ), ) + await runtime.add_subscription(TypeSubscription("default", "MathSolver1")) await runtime.register( "MathSolver2", lambda: MathSolver( @@ -231,6 +242,7 @@ async def main(question: str) -> None: max_round=3, ), ) + await runtime.add_subscription(TypeSubscription("default", "MathSolver2")) await runtime.register( "MathSolver3", lambda: MathSolver( @@ -239,6 +251,7 @@ async def main(question: str) -> None: max_round=3, ), ) + await runtime.add_subscription(TypeSubscription("default", "MathSolver3")) await runtime.register( "MathSolver4", lambda: MathSolver( @@ -247,13 +260,14 @@ async def main(question: str) -> None: max_round=3, ), ) + await runtime.add_subscription(TypeSubscription("default", "MathSolver4")) # Register the aggregator agent. await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4)) run_context = runtime.start() # Send a math problem to the aggregator agent. - await runtime.publish_message(Question(content=question), namespace="default") + await runtime.publish_message(Question(content=question), topic_id=TopicId("default", "default")) await run_context.stop_when_idle() diff --git a/python/samples/tool-use/coding_direct.py b/python/samples/tool-use/coding_direct.py index 5f278d114..fcb1799a9 100644 --- a/python/samples/tool-use/coding_direct.py +++ b/python/samples/tool-use/coding_direct.py @@ -30,7 +30,7 @@ from agnext.components.models import ( ) from agnext.components.tool_agent import ToolAgent, ToolException from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema -from agnext.core import AgentId +from agnext.core import AgentId, AgentInstantiationContext sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -107,21 +107,21 @@ async def main() -> None: ) ] # Register agents. - tool_executor_agent = await runtime.register_and_get( + await runtime.register( "tool_executor_agent", lambda: ToolAgent( description="Tool Executor Agent", tools=tools, ), ) - tool_use_agent = await runtime.register_and_get( + await runtime.register( "tool_enabled_agent", lambda: ToolUseAgent( description="Tool Use Agent", system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")], model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), tool_schema=[tool.schema for tool in tools], - tool_agent=tool_executor_agent, + tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key), ), ) @@ -129,7 +129,7 @@ async def main() -> None: # Send a task to the tool user. response = await runtime.send_message( - Message("Run the following Python code: print('Hello, World!')"), tool_use_agent + Message("Run the following Python code: print('Hello, World!')"), AgentId("tool_enabled_agent", "default") ) print(response.content) diff --git a/python/samples/tool-use/coding_direct_with_intercept.py b/python/samples/tool-use/coding_direct_with_intercept.py index 1eba7bfa8..044802aba 100644 --- a/python/samples/tool-use/coding_direct_with_intercept.py +++ b/python/samples/tool-use/coding_direct_with_intercept.py @@ -16,7 +16,7 @@ from agnext.components.code_executor import LocalCommandLineCodeExecutor from agnext.components.models import SystemMessage from agnext.components.tool_agent import ToolAgent, ToolException from agnext.components.tools import PythonCodeExecutionTool, Tool -from agnext.core import AgentId +from agnext.core import AgentId, AgentInstantiationContext from agnext.core.intervention import DefaultInterventionHandler, DropMessage sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -48,21 +48,21 @@ async def main() -> None: ) ] # Register agents. - tool_executor_agent = await runtime.register_and_get( + await runtime.register( "tool_executor_agent", lambda: ToolAgent( description="Tool Executor Agent", tools=tools, ), ) - tool_use_agent = await runtime.register_and_get( + await runtime.register( "tool_enabled_agent", lambda: ToolUseAgent( description="Tool Use Agent", system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")], model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), tool_schema=[tool.schema for tool in tools], - tool_agent=tool_executor_agent, + tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key), ), ) @@ -70,7 +70,7 @@ async def main() -> None: # Send a task to the tool user. response = await runtime.send_message( - Message("Run the following Python code: print('Hello, World!')"), tool_use_agent + Message("Run the following Python code: print('Hello, World!')"), AgentId("tool_enabled_agent", "default") ) print(response.content) diff --git a/python/samples/tool-use/coding_pub_sub.py b/python/samples/tool-use/coding_pub_sub.py index 656f1abda..d7203c05d 100644 --- a/python/samples/tool-use/coding_pub_sub.py +++ b/python/samples/tool-use/coding_pub_sub.py @@ -21,6 +21,7 @@ from typing import Dict, List from agnext.application import SingleThreadedAgentRuntime from agnext.components import FunctionCall, TypeRoutedAgent, message_handler +from agnext.components._type_subscription import TypeSubscription from agnext.components.code_executor import LocalCommandLineCodeExecutor from agnext.components.models import ( AssistantMessage, @@ -32,6 +33,7 @@ from agnext.components.models import ( UserMessage, ) from agnext.components.tools import PythonCodeExecutionTool, Tool +from agnext.core import TopicId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -88,7 +90,8 @@ class ToolExecutorAgent(TypeRoutedAgent): session_id=message.session_id, result=FunctionExecutionResult(content=result_as_str, call_id=message.function_call.id), ) - await self.publish_message(task_result) + assert ctx.topic_id is not None + await self.publish_message(task_result, topic_id=ctx.topic_id) class ToolUseAgent(TypeRoutedAgent): @@ -126,7 +129,8 @@ class ToolUseAgent(TypeRoutedAgent): if isinstance(response.content, str): # If the response is a string, just publish the response. response_message = AgentResponse(content=response.content) - await self.publish_message(response_message) + assert ctx.topic_id is not None + await self.publish_message(response_message, topic_id=ctx.topic_id) print(f"AI Response: {response.content}") return @@ -139,7 +143,8 @@ class ToolUseAgent(TypeRoutedAgent): for function_call in response.content: task = ToolExecutionTask(session_id=session_id, function_call=function_call) self._tool_counter[session_id] += 1 - await self.publish_message(task) + assert ctx.topic_id is not None + await self.publish_message(task, topic_id=ctx.topic_id) @message_handler async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None: @@ -165,10 +170,11 @@ class ToolUseAgent(TypeRoutedAgent): self._sessions[message.session_id].append( AssistantMessage(content=response.content, source=self.metadata["type"]) ) + assert ctx.topic_id is not None # If the response is a string, just publish the response. if isinstance(response.content, str): response_message = AgentResponse(content=response.content) - await self.publish_message(response_message) + await self.publish_message(response_message, topic_id=ctx.topic_id) self._tool_results.pop(message.session_id) self._tool_counter.pop(message.session_id) print(f"AI Response: {response.content}") @@ -179,7 +185,7 @@ class ToolUseAgent(TypeRoutedAgent): for function_call in response.content: task = ToolExecutionTask(session_id=message.session_id, function_call=function_call) self._tool_counter[message.session_id] += 1 - await self.publish_message(task) + await self.publish_message(task, topic_id=ctx.topic_id) async def main() -> None: @@ -192,6 +198,7 @@ async def main() -> None: ] # Register agents. await runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools)) + await runtime.add_subscription(TypeSubscription("default", "tool_executor")) await runtime.register( "tool_use_agent", lambda: ToolUseAgent( @@ -201,12 +208,13 @@ async def main() -> None: tools=tools, ), ) + await runtime.add_subscription(TypeSubscription("default", "tool_use_agent")) run_context = runtime.start() # Publish a task. await runtime.publish_message( - UserRequest("Run the following Python code: print('Hello, World!')"), namespace="default" + UserRequest("Run the following Python code: print('Hello, World!')"), topic_id=TopicId("default", "default") ) await run_context.stop_when_idle() diff --git a/python/samples/tool-use/custom_tool_direct.py b/python/samples/tool-use/custom_tool_direct.py index 2bfe6c41a..316c20aa8 100644 --- a/python/samples/tool-use/custom_tool_direct.py +++ b/python/samples/tool-use/custom_tool_direct.py @@ -15,11 +15,13 @@ from agnext.components.models import ( ) from agnext.components.tool_agent import ToolAgent from agnext.components.tools import FunctionTool, Tool +from agnext.core import AgentInstantiationContext from typing_extensions import Annotated sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__)))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import AgentId from coding_direct import Message, ToolUseAgent from common.utils import get_chat_completion_client_from_envs @@ -42,23 +44,24 @@ async def main() -> None: ) ] # Register agents. - tool_executor_agent = await runtime.register_and_get( + await runtime.register( "tool_executor_agent", lambda: ToolAgent( description="Tool Executor Agent", tools=tools, ), ) - tool_use_agent = await runtime.register_and_get( + await runtime.register( "tool_enabled_agent", lambda: ToolUseAgent( description="Tool Use Agent", system_messages=[SystemMessage("You are a helpful AI Assistant. Use your tools to solve problems.")], model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"), tool_schema=[tool.schema for tool in tools], - tool_agent=tool_executor_agent, + tool_agent=AgentId("tool_executor_agent", AgentInstantiationContext.current_agent_id().key), ), ) + tool_use_agent = AgentId("tool_enabled_agent", "default") run_context = runtime.start() diff --git a/python/samples/worker/run_worker_pub_sub.py b/python/samples/worker/run_worker_pub_sub.py index e46aa5f21..ef929a2c8 100644 --- a/python/samples/worker/run_worker_pub_sub.py +++ b/python/samples/worker/run_worker_pub_sub.py @@ -4,7 +4,8 @@ from dataclasses import dataclass from agnext.application import WorkerAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext +from agnext.components._type_subscription import TypeSubscription +from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, TopicId @dataclass @@ -38,11 +39,14 @@ class ReceiveAgent(TypeRoutedAgent): @message_handler async def on_greet(self, message: Greeting, ctx: MessageContext) -> None: - await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}")) + assert ctx.topic_id is not None + await self.publish_message(ReturnedGreeting(f"Returned greeting: {message.content}"), topic_id=ctx.topic_id) @message_handler async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None: - await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}")) + assert ctx.topic_id is not None + + await self.publish_message(ReturnedFeedback(f"Returned feedback: {message.content}"), topic_id=ctx.topic_id) class GreeterAgent(TypeRoutedAgent): @@ -51,11 +55,15 @@ class GreeterAgent(TypeRoutedAgent): @message_handler async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None: - await self.publish_message(Greeting(f"Hello, {message.content}!")) + assert ctx.topic_id is not None + + await self.publish_message(Greeting(f"Hello, {message.content}!"), topic_id=ctx.topic_id) @message_handler async def on_returned_greet(self, message: ReturnedGreeting, ctx: MessageContext) -> None: - await self.publish_message(Feedback(f"Feedback: {message.content}")) + assert ctx.topic_id is not None + + await self.publish_message(Feedback(f"Feedback: {message.content}"), topic_id=ctx.topic_id) async def main() -> None: @@ -68,9 +76,11 @@ async def main() -> None: await runtime.start(host_connection_string="localhost:50051") await runtime.register("reciever", lambda: ReceiveAgent()) + await runtime.add_subscription(TypeSubscription("default", "reciever")) await runtime.register("greeter", lambda: GreeterAgent()) + await runtime.add_subscription(TypeSubscription("default", "greeter")) - await runtime.publish_message(AskToGreet("Hello World!"), namespace="default") + await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default")) # Just to keep the runtime running try: diff --git a/python/samples/worker/run_worker_rpc.py b/python/samples/worker/run_worker_rpc.py index 2168182b7..52c065a3e 100644 --- a/python/samples/worker/run_worker_rpc.py +++ b/python/samples/worker/run_worker_rpc.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from agnext.application import WorkerAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, MessageContext +from agnext.core import MESSAGE_TYPE_REGISTRY, AgentId, AgentInstantiationContext, MessageContext, TopicId @dataclass @@ -43,7 +43,8 @@ class GreeterAgent(TypeRoutedAgent): @message_handler async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None: response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id) - await self.publish_message(Feedback(f"Feedback: {response.content}")) + assert ctx.topic_id is not None + await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=ctx.topic_id) async def main() -> None: @@ -54,10 +55,11 @@ async def main() -> None: await runtime.start(host_connection_string="localhost:50051") await runtime.register("reciever", lambda: ReceiveAgent()) - reciever = await runtime.get("reciever") - await runtime.register("greeter", lambda: GreeterAgent(reciever)) + await runtime.register( + "greeter", lambda: GreeterAgent(AgentId("reciever", AgentInstantiationContext.current_agent_id().key)) + ) - await runtime.publish_message(AskToGreet("Hello World!"), namespace="default") + await runtime.publish_message(AskToGreet("Hello World!"), topic_id=TopicId("default", "default")) # Just to keep the runtime running try: diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index 6a44cb573..5cfd3e5d4 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -12,13 +12,13 @@ from dataclasses import dataclass from enum import Enum from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast +from agnext.core import Subscription, TopicId + from ..core import ( - MESSAGE_TYPE_REGISTRY, Agent, AgentId, AgentInstantiationContext, AgentMetadata, - AgentProxy, AgentRuntime, CancellationToken, MessageContext, @@ -38,7 +38,7 @@ class PublishMessageEnvelope: message: Any cancellation_token: CancellationToken sender: AgentId | None - namespace: str + topic_id: TopicId @dataclass(kw_only=True) @@ -124,16 +124,18 @@ class SingleThreadedAgentRuntime(AgentRuntime): def __init__(self, *, intervention_handler: InterventionHandler | None = None) -> None: self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] # (namespace, type) -> List[AgentId] - self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set) self._agent_factories: Dict[ str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] ] = {} self._instantiated_agents: Dict[AgentId, Agent] = {} self._intervention_handler = intervention_handler - self._known_namespaces: set[str] = set() self._outstanding_tasks = Counter() self._background_tasks: Set[Task[Any]] = set() + self._subscriptions: List[Subscription] = [] + self._seen_topics: Set[TopicId] = set() + self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list) + @property def unprocessed_messages( self, @@ -177,8 +179,6 @@ class SingleThreadedAgentRuntime(AgentRuntime): if sender is not None and sender.key != recipient.key: raise ValueError("Sender and recipient must be in the same namespace to communicate.") - await self._process_seen_namespace(recipient.key) - content = message.__dict__ if hasattr(message, "__dict__") else message logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") @@ -199,8 +199,8 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def publish_message( self, message: Any, + topic_id: TopicId, *, - namespace: str | None = None, sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> None: @@ -219,26 +219,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): # ) # ) - if sender is None and namespace is None: - raise ValueError("Namespace must be provided if sender is not provided.") - - sender_namespace = sender.key if sender is not None else None - explicit_namespace = namespace - if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace: - raise ValueError( - f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}" - ) - - assert explicit_namespace is not None or sender_namespace is not None - namespace = cast(str, explicit_namespace or sender_namespace) - await self._process_seen_namespace(namespace) - self._message_queue.append( PublishMessageEnvelope( - message=message, - cancellation_token=cancellation_token, - sender=sender, - namespace=namespace, + message=message, cancellation_token=cancellation_token, sender=sender, topic_id=topic_id ) ) @@ -300,12 +283,13 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._outstanding_tasks.decrement() async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: + self._build_for_new_topic(message_envelope.topic_id) responses: List[Awaitable[Any]] = [] - target_namespace = message_envelope.namespace - for agent_id in self._per_type_subscribers[ - (target_namespace, MESSAGE_TYPE_REGISTRY.type_name(message_envelope.message)) - ]: - if message_envelope.sender is not None and agent_id.type == message_envelope.sender.type: + + recipients = self._subscribed_recipients[message_envelope.topic_id] + for agent_id in recipients: + # Avoid sending the message back to the sender + if message_envelope.sender is not None and agent_id == message_envelope.sender: continue sender_agent = ( @@ -326,8 +310,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): # ) message_context = MessageContext( sender=message_envelope.sender, - # TODO: topic_id - topic_id=None, + topic_id=message_envelope.topic_id, is_rpc=False, cancellation_token=message_envelope.cancellation_token, ) @@ -460,16 +443,12 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def register( self, - name: str, + type: str, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], ) -> None: - if name in self._agent_factories: - raise ValueError(f"Agent with name {name} already exists.") - self._agent_factories[name] = agent_factory - - # For all already prepared namespaces we need to prepare this agent - for namespace in self._known_namespaces: - await self._get_agent(AgentId(type=name, key=namespace)) + if type in self._agent_factories: + raise ValueError(f"Agent with type {type} already exists.") + self._agent_factories[type] = agent_factory async def _invoke_agent_factory( self, @@ -496,7 +475,6 @@ class SingleThreadedAgentRuntime(AgentRuntime): return agent async def _get_agent(self, agent_id: AgentId) -> Agent: - await self._process_seen_namespace(agent_id.key) if agent_id in self._instantiated_agents: return self._instantiated_agents[agent_id] @@ -504,20 +482,10 @@ class SingleThreadedAgentRuntime(AgentRuntime): raise LookupError(f"Agent with name {agent_id.type} not found.") agent_factory = self._agent_factories[agent_id.type] - agent = await self._invoke_agent_factory(agent_factory, agent_id) - for message_type in agent.metadata["subscriptions"]: - self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id) self._instantiated_agents[agent_id] = agent return agent - async def get(self, name: str, *, namespace: str = "default") -> AgentId: - return (await self._get_agent(AgentId(type=name, key=namespace))).id - - async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: - id = await self.get(name, namespace=namespace) - return AgentProxy(id, self) - # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] if id.type not in self._agent_factories: @@ -531,12 +499,40 @@ class SingleThreadedAgentRuntime(AgentRuntime): return agent_instance - # Hydrate the agent instances in a namespace. The primary reason for this is - # to ensure message type subscriptions are set up. - async def _process_seen_namespace(self, namespace: str) -> None: - if namespace in self._known_namespaces: + async def add_subscription(self, subscription: Subscription) -> None: + # Check if the subscription already exists + if any(sub.id == subscription.id for sub in self._subscriptions): + raise ValueError("Subscription already exists") + + if len(self._seen_topics) > 0: + raise NotImplementedError("Cannot add subscription after topics have been seen yet") + + self._subscriptions.append(subscription) + + async def remove_subscription(self, id: str) -> None: + # Check if the subscription exists + if not any(sub.id == id for sub in self._subscriptions): + raise ValueError("Subscription does not exist") + + def is_not_sub(x: Subscription) -> bool: + return x.id != id + + self._subscriptions = list(filter(is_not_sub, self._subscriptions)) + + # Rebuild the subscriptions + self._rebuild_subscriptions(self._seen_topics) + + # TODO: optimize this... + def _rebuild_subscriptions(self, topics: Set[TopicId]) -> None: + self._subscribed_recipients.clear() + for topic in topics: + self._build_for_new_topic(topic) + + def _build_for_new_topic(self, topic: TopicId) -> None: + if topic in self._seen_topics: return - self._known_namespaces.add(namespace) - for name in self._known_agent_names: - await self._get_agent(AgentId(type=name, key=namespace)) + self._seen_topics.add(topic) + for subscription in self._subscriptions: + if subscription.is_match(topic): + self._subscribed_recipients[topic].append(subscription.map_to_agent(topic)) diff --git a/python/src/agnext/application/_worker_runtime.py b/python/src/agnext/application/_worker_runtime.py index 556f0a7bf..42e58bc63 100644 --- a/python/src/agnext/application/_worker_runtime.py +++ b/python/src/agnext/application/_worker_runtime.py @@ -28,9 +28,9 @@ import grpc from grpc.aio import StreamStreamCall from typing_extensions import Self -from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext +from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, Subscription, TopicId -from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentProxy, AgentRuntime, CancellationToken +from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken from .protos import AgentId as AgentIdProto from .protos import ( AgentRpcStub, @@ -153,6 +153,9 @@ class WorkerAgentRuntime(AgentRuntime): self._next_request_id = 0 self._host_connection: HostConnection | None = None self._background_tasks: Set[Task[Any]] = set() + self._subscriptions: List[Subscription] = [] + self._seen_topics: Set[TopicId] = set() + self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list) async def start(self, host_connection_string: str) -> None: if self._running: @@ -245,29 +248,25 @@ class WorkerAgentRuntime(AgentRuntime): async def publish_message( self, message: Any, + topic_id: TopicId, *, - namespace: str | None = None, sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> None: - if not self._running: - raise ValueError("Runtime must be running when publishing message.") assert self._host_connection is not None - sender_namespace = sender.key if sender is not None else None - explicit_namespace = namespace - if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace: - raise ValueError( - f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}" - ) - assert explicit_namespace is not None or sender_namespace is not None - actual_namespace = cast(str, explicit_namespace or sender_namespace) - await self._process_seen_namespace(actual_namespace) message_type = MESSAGE_TYPE_REGISTRY.type_name(message) serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type) - message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message)) - task = asyncio.create_task(self._host_connection.send(message)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + message = Message( + event=Event( + topic_type=topic_id.type, topic_source=topic_id.source, data_type=message_type, data=serialized_message + ) + ) + + async def write_message() -> None: + assert self._host_connection is not None + await self._host_connection.send(message) + + await asyncio.create_task(write_message()) async def save_state(self) -> Mapping[str, Any]: raise NotImplementedError("Saving state is not yet implemented.") @@ -284,26 +283,6 @@ class WorkerAgentRuntime(AgentRuntime): async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: raise NotImplementedError("Agent load_state is not yet implemented.") - async def register( - self, - name: str, - agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], - ) -> None: - if not self._running: - raise ValueError("Runtime must be running when registering agent.") - if name in self._agent_factories: - raise ValueError(f"Agent with name {name} already exists.") - self._agent_factories[name] = agent_factory - - # For all already prepared namespaces we need to prepare this agent - for namespace in self._known_namespaces: - await self._get_agent(AgentId(type=name, key=namespace)) - - assert self._host_connection is not None - message = Message(registerAgentType=RegisterAgentType(type=name)) - await self._host_connection.send(message) - logger.info("Sent registerAgentType message for %s", name) - async def _process_request(self, request: RpcRequest) -> None: assert self._host_connection is not None target = AgentId(request.target.name, request.target.namespace) @@ -347,27 +326,41 @@ class WorkerAgentRuntime(AgentRuntime): future.set_result(response.result) async def _process_event(self, event: Event) -> None: - message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type) - namespace = event.namespace - responses: List[Awaitable[Any]] = [] - for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]: - # TODO: skip the sender? - message_context = MessageContext( - sender=None, - topic_id=None, - is_rpc=False, - cancellation_token=CancellationToken(), - ) - agent = await self._get_agent(agent_id) - future = agent.on_message(message, ctx=message_context) - responses.append(future) + ... + # message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.data_type) - try: - _ = await asyncio.gather(*responses) - except BaseException as e: - if isinstance(e, asyncio.CancelledError): - return - event_logger.error("Error handling event message", exc_info=e) + # for agent_id in self._per_type_subscribers[ + # (namespace, MESSAGE_TYPE_REGISTRY.type_name(message)) + # ]: + + # agent = await self._get_agent(agent_id) + # message_context = MessageContext( + # # TODO: should sender be in the proto even for published events? + # sender=None, + # # TODO: topic_id + # topic_id=None, + # is_rpc=False, + # cancellation_token=CancellationToken(), + # ) + # try: + # await agent.on_message(message, ctx=message_context) + # logger.info("%s handled event %s", agent_id, message) + # except Exception as e: + # event_logger.error("Error handling message", exc_info=e) + + async def register( + self, + type: str, + agent_factory: Callable[[], T | Awaitable[T]], + ) -> None: + if type in self._agent_factories: + raise ValueError(f"Agent with type {type} already exists.") + self._agent_factories[type] = agent_factory + + assert self._host_connection is not None + message = Message(registerAgentType=RegisterAgentType(type=type)) + await self._host_connection.send(message) + logger.info("Sent registerAgentType message for %s", type) async def _invoke_agent_factory( self, @@ -394,7 +387,6 @@ class WorkerAgentRuntime(AgentRuntime): return agent async def _get_agent(self, agent_id: AgentId) -> Agent: - await self._process_seen_namespace(agent_id.key) if agent_id in self._instantiated_agents: return self._instantiated_agents[agent_id] @@ -402,32 +394,16 @@ class WorkerAgentRuntime(AgentRuntime): raise ValueError(f"Agent with name {agent_id.type} not found.") agent_factory = self._agent_factories[agent_id.type] - agent = await self._invoke_agent_factory(agent_factory, agent_id) - - for message_type in agent.metadata["subscriptions"]: - self._per_type_subscribers[(agent_id.key, message_type)].add(agent_id) - self._instantiated_agents[agent_id] = agent return agent - async def get(self, name: str, *, namespace: str = "default") -> AgentId: - return (await self._get_agent(AgentId(type=name, key=namespace))).id - - async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: - id = await self.get(name, namespace=namespace) - return AgentProxy(id, self) - # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] raise NotImplementedError("try_get_underlying_agent_instance is not yet implemented.") - # Hydrate the agent instances in a namespace. The primary reason for this is - # to ensure message type subscriptions are set up. - async def _process_seen_namespace(self, namespace: str) -> None: - if namespace in self._known_namespaces: - return + async def add_subscription(self, subscription: Subscription) -> None: + raise NotImplementedError("Subscriptions are not yet implemented.") - self._known_namespaces.add(namespace) - for name in self._known_agent_names: - await self._get_agent(AgentId(type=name, key=namespace)) + async def remove_subscription(self, id: str) -> None: + raise NotImplementedError("Subscriptions are not yet implemented.") diff --git a/python/src/agnext/application/protos/agent_worker_pb2.py b/python/src/agnext/application/protos/agent_worker_pb2.py index 075028366..36c3968cc 100644 --- a/python/src/agnext/application/protos/agent_worker_pb2.py +++ b/python/src/agnext/application/protos/agent_worker_pb2.py @@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x96\x01\n\x05\x45vent\x12\x11\n\tnamespace\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\t\x12-\n\x08metadata\x18\x04 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\"*\n\x07\x41gentId\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnamespace\x18\x02 \x01(\t\"\xe5\x01\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x1f\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentId\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x01(\t\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa6\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xb2\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12\x11\n\tdata_type\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"!\n\x11RegisterAgentType\x12\x0c\n\x04type\x18\x01 \x01(\t\"\xbc\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x36\n\x11registerAgentType\x18\x04 \x01(\x0b\x32\x19.agents.RegisterAgentTypeH\x00\x42\t\n\x07message2?\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -38,13 +38,13 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=257 _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=304 _globals['_EVENT']._serialized_start=476 - _globals['_EVENT']._serialized_end=626 + _globals['_EVENT']._serialized_end=654 _globals['_EVENT_METADATAENTRY']._serialized_start=257 _globals['_EVENT_METADATAENTRY']._serialized_end=304 - _globals['_REGISTERAGENTTYPE']._serialized_start=628 - _globals['_REGISTERAGENTTYPE']._serialized_end=661 - _globals['_MESSAGE']._serialized_start=664 - _globals['_MESSAGE']._serialized_end=852 - _globals['_AGENTRPC']._serialized_start=854 - _globals['_AGENTRPC']._serialized_end=917 + _globals['_REGISTERAGENTTYPE']._serialized_start=656 + _globals['_REGISTERAGENTTYPE']._serialized_end=689 + _globals['_MESSAGE']._serialized_start=692 + _globals['_MESSAGE']._serialized_end=880 + _globals['_AGENTRPC']._serialized_start=882 + _globals['_AGENTRPC']._serialized_end=945 # @@protoc_insertion_point(module_scope) diff --git a/python/src/agnext/application/protos/agent_worker_pb2.pyi b/python/src/agnext/application/protos/agent_worker_pb2.pyi index b7dc151ba..4e36c9c73 100644 --- a/python/src/agnext/application/protos/agent_worker_pb2.pyi +++ b/python/src/agnext/application/protos/agent_worker_pb2.pyi @@ -14,8 +14,6 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor @typing.final class AgentId(google.protobuf.message.Message): - """TODO: update""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor NAME_FIELD_NUMBER: builtins.int @@ -143,24 +141,27 @@ class Event(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - NAMESPACE_FIELD_NUMBER: builtins.int - TYPE_FIELD_NUMBER: builtins.int + TOPIC_TYPE_FIELD_NUMBER: builtins.int + TOPIC_SOURCE_FIELD_NUMBER: builtins.int + DATA_TYPE_FIELD_NUMBER: builtins.int DATA_FIELD_NUMBER: builtins.int METADATA_FIELD_NUMBER: builtins.int - namespace: builtins.str - type: builtins.str + topic_type: builtins.str + topic_source: builtins.str + data_type: builtins.str data: builtins.str @property def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... def __init__( self, *, - namespace: builtins.str = ..., - type: builtins.str = ..., + topic_type: builtins.str = ..., + topic_source: builtins.str = ..., + data_type: builtins.str = ..., data: builtins.str = ..., metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "metadata", b"metadata", "namespace", b"namespace", "type", b"type"]) -> None: ... + def ClearField(self, field_name: typing.Literal["data", b"data", "data_type", b"data_type", "metadata", b"metadata", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ... global___Event = Event diff --git a/python/src/agnext/components/_closure_agent.py b/python/src/agnext/components/_closure_agent.py index 21ffd9511..f048db669 100644 --- a/python/src/agnext/components/_closure_agent.py +++ b/python/src/agnext/components/_closure_agent.py @@ -72,7 +72,6 @@ class ClosureAgent(Agent): key=self._id.key, type=self._id.type, description=self._description, - subscriptions=self._subscriptions, ) @property diff --git a/python/src/agnext/components/_type_routed_agent.py b/python/src/agnext/components/_type_routed_agent.py index ba5ee5384..0232f2716 100644 --- a/python/src/agnext/components/_type_routed_agent.py +++ b/python/src/agnext/components/_type_routed_agent.py @@ -142,12 +142,12 @@ class TypeRoutedAgent(BaseAgent): message_handler = cast(MessageHandler[Any, Any], handler) for target_type in message_handler.target_types: self._handlers[target_type] = message_handler - subscriptions = list(self._handlers.keys()) + for message_type in self._handlers.keys(): if not MESSAGE_TYPE_REGISTRY.is_registered(MESSAGE_TYPE_REGISTRY.type_name(message_type)): MESSAGE_TYPE_REGISTRY.add_type(message_type) - subscriptions_str = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscriptions] - super().__init__(description, subscriptions_str) + + super().__init__(description) async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: key_type: Type[Any] = type(message) # type: ignore diff --git a/python/src/agnext/components/_type_subscription.py b/python/src/agnext/components/_type_subscription.py index 437eabb7b..47e292a74 100644 --- a/python/src/agnext/components/_type_subscription.py +++ b/python/src/agnext/components/_type_subscription.py @@ -42,5 +42,4 @@ class TypeSubscription(Subscription): if not self.is_match(topic_id): raise CantHandleException("TopicId does not match the subscription") - # TODO: Update agentid to reflect agent type and key return AgentId(type=self._agent_type, key=topic_id.source) diff --git a/python/src/agnext/core/__init__.py b/python/src/agnext/core/__init__.py index 1ce9345b9..851a4f347 100644 --- a/python/src/agnext/core/__init__.py +++ b/python/src/agnext/core/__init__.py @@ -12,7 +12,7 @@ from ._agent_runtime import AgentRuntime from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken from ._message_context import MessageContext -from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer +from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer from ._subscription import Subscription from ._topic import TopicId @@ -32,4 +32,5 @@ __all__ = [ "TopicId", "Subscription", "MessageContext", + "Serialization", ] diff --git a/python/src/agnext/core/_agent_metadata.py b/python/src/agnext/core/_agent_metadata.py index b4e8db482..abdf92035 100644 --- a/python/src/agnext/core/_agent_metadata.py +++ b/python/src/agnext/core/_agent_metadata.py @@ -1,8 +1,7 @@ -from typing import Sequence, TypedDict +from typing import TypedDict class AgentMetadata(TypedDict): type: str key: str description: str - subscriptions: Sequence[str] diff --git a/python/src/agnext/core/_agent_runtime.py b/python/src/agnext/core/_agent_runtime.py index 617b9ca62..0f05f2f2d 100644 --- a/python/src/agnext/core/_agent_runtime.py +++ b/python/src/agnext/core/_agent_runtime.py @@ -5,8 +5,9 @@ from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, r from ._agent import Agent from ._agent_id import AgentId from ._agent_metadata import AgentMetadata -from ._agent_proxy import AgentProxy from ._cancellation_token import CancellationToken +from ._subscription import Subscription +from ._topic import TopicId # Undeliverable - error @@ -45,8 +46,8 @@ class AgentRuntime(Protocol): async def publish_message( self, message: Any, + topic_id: TopicId, *, - namespace: str | None = None, sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> None: @@ -56,23 +57,24 @@ class AgentRuntime(Protocol): Args: message (Any): The message to publish. - namespace (str | None, optional): The namespace to publish to. Defaults to None. + topic (TopicId): The topic to publish the message to. sender (AgentId | None, optional): The agent which sent the message. Defaults to None. cancellation_token (CancellationToken | None, optional): Token used to cancel an in progress . Defaults to None. Raises: UndeliverableException: If the message cannot be delivered. """ + ... async def register( self, - name: str, + type: str, agent_factory: Callable[[], T | Awaitable[T]], ) -> None: - """Register an agent factory with the runtime associated with a specific name. The name must be unique. + """Register an agent factory with the runtime associated with a specific type. The type must be unique. Args: - name (str): The name of the type agent this factory creates. + type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes. agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID. @@ -93,30 +95,6 @@ class AgentRuntime(Protocol): ... - async def get(self, name: str, *, namespace: str = "default") -> AgentId: - """Get an agent by name and namespace. - - Args: - name (str): The name of the agent. - namespace (str, optional): The namespace of the agent. Defaults to "default". - - Returns: - AgentId: The agent id. - """ - ... - - async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: - """Get a proxy for an agent by name and namespace. - - Args: - name (str): The name of the agent. - namespace (str, optional): The namespace of the agent. Defaults to "default". - - Returns: - AgentProxy: The agent proxy. - """ - ... - # TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737 async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment] """Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases. @@ -137,46 +115,6 @@ class AgentRuntime(Protocol): """ ... - async def register_and_get( - self, - name: str, - agent_factory: Callable[[], T | Awaitable[T]], - *, - namespace: str = "default", - ) -> AgentId: - """Register an agent factory with the runtime associated with a specific name and get the agent id. The name must be unique. - - Args: - name (str): The name of the type agent this factory creates. - agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `agnext.core.AgentInstantiationContext` to access variables like the current runtime and agent ID. - namespace (str, optional): The namespace of the agent. Defaults to "default". - - Returns: - AgentId: The agent id. - """ - await self.register(name, agent_factory) - return await self.get(name, namespace=namespace) - - async def register_and_get_proxy( - self, - name: str, - agent_factory: Callable[[], T | Awaitable[T]], - *, - namespace: str = "default", - ) -> AgentProxy: - """Register an agent factory with the runtime associated with a specific name and get the agent proxy. The name must be unique. - - Args: - name (str): The name of the type agent this factory creates. - agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. - namespace (str, optional): The namespace of the agent. Defaults to "default". - - Returns: - AgentProxy: The agent proxy. - """ - await self.register(name, agent_factory) - return await self.get_proxy(name, namespace=namespace) - async def save_state(self) -> Mapping[str, Any]: """Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`. @@ -227,3 +165,22 @@ class AgentRuntime(Protocol): state (Mapping[str, Any]): The saved state. """ ... + + async def add_subscription(self, subscription: Subscription) -> None: + """Add a new subscription that the runtime should fulfill when processing published messages + + Args: + subscription (Subscription): The subscription to add + """ + ... + + async def remove_subscription(self, id: str) -> None: + """Remove a subscription from the runtime + + Args: + id (str): id of the subscription to remove + + Raises: + LookupError: If the subscription does not exist + """ + ... diff --git a/python/src/agnext/core/_base_agent.py b/python/src/agnext/core/_base_agent.py index 348ed9559..b5239faf9 100644 --- a/python/src/agnext/core/_base_agent.py +++ b/python/src/agnext/core/_base_agent.py @@ -1,6 +1,6 @@ import warnings from abc import ABC, abstractmethod -from typing import Any, Mapping, Sequence +from typing import Any, Mapping from ._agent import Agent from ._agent_id import AgentId @@ -9,20 +9,16 @@ from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._cancellation_token import CancellationToken from ._message_context import MessageContext +from ._topic import TopicId class BaseAgent(ABC, Agent): @property def metadata(self) -> AgentMetadata: assert self._id is not None - return AgentMetadata( - key=self._id.key, - type=self._id.type, - description=self._description, - subscriptions=self._subscriptions, - ) + return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) - def __init__(self, description: str, subscriptions: Sequence[str]) -> None: + def __init__(self, description: str) -> None: try: runtime = AgentInstantiationContext.current_runtime() id = AgentInstantiationContext.current_agent_id() @@ -36,7 +32,6 @@ class BaseAgent(ABC, Agent): if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description - self._subscriptions = subscriptions @property def type(self) -> str: @@ -74,10 +69,11 @@ class BaseAgent(ABC, Agent): async def publish_message( self, message: Any, + topic_id: TopicId, *, cancellation_token: CancellationToken | None = None, ) -> None: - await self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token) + await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token) def save_state(self) -> Mapping[str, Any]: warnings.warn("save_state not implemented", stacklevel=2) diff --git a/python/src/agnext/core/_subscription.py b/python/src/agnext/core/_subscription.py index a75492d71..d606fc176 100644 --- a/python/src/agnext/core/_subscription.py +++ b/python/src/agnext/core/_subscription.py @@ -1,10 +1,11 @@ -from typing import Protocol +from typing import Protocol, runtime_checkable -from agnext.core._agent_id import AgentId +from agnext.core import AgentId from ._topic import TopicId +@runtime_checkable class Subscription(Protocol): """Subscriptions define the topics that an agent is interested in.""" @@ -19,6 +20,20 @@ class Subscription(Protocol): """ ... + def __eq__(self, other: object) -> bool: + """Check if two subscriptions are equal. + + Args: + other (object): Other subscription to compare against. + + Returns: + bool: True if the subscriptions are equal, False otherwise. + """ + if not isinstance(other, Subscription): + return False + + return self.id == other.id + def is_match(self, topic_id: TopicId) -> bool: """Check if a given topic_id matches the subscription. diff --git a/python/src/agnext/core/_topic.py b/python/src/agnext/core/_topic.py index f9307d60c..f693a3eb4 100644 --- a/python/src/agnext/core/_topic.py +++ b/python/src/agnext/core/_topic.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -@dataclass +@dataclass(eq=True, frozen=True) class TopicId: type: str """Type of the event that this topic_id contains. Adhere's to the cloud event spec. diff --git a/python/teams/team-one/examples/example.py b/python/teams/team-one/examples/example.py index 5bab59eca..09adffc0c 100644 --- a/python/teams/team-one/examples/example.py +++ b/python/teams/team-one/examples/example.py @@ -3,6 +3,7 @@ import logging from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME +from agnext.core import AgentId, AgentProxy from team_one.agents.coder import Coder, Executor from team_one.agents.orchestrator import LedgerOrchestrator from team_one.agents.user_proxy import UserProxy @@ -15,18 +16,22 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register agents. - coder = await runtime.register_and_get_proxy( + await runtime.register( "Coder", lambda: Coder(model_client=create_completion_client_from_env()), ) + coder = AgentProxy(AgentId("Coder", "default"), runtime) - executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code")) + await runtime.register("Executor", lambda: Executor("A agent for executing code")) + executor = AgentProxy(AgentId("Executor", "default"), runtime) - user_proxy = await runtime.register_and_get_proxy( + await runtime.register( "UserProxy", lambda: UserProxy(description="The current user interacting with you."), ) + user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime) + # TODO: doesn't work for more than default key await runtime.register( "orchestrator", lambda: LedgerOrchestrator( diff --git a/python/teams/team-one/examples/example_coder.py b/python/teams/team-one/examples/example_coder.py index d45db0999..9b45539fa 100644 --- a/python/teams/team-one/examples/example_coder.py +++ b/python/teams/team-one/examples/example_coder.py @@ -3,6 +3,7 @@ import logging from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME +from agnext.core import AgentId, AgentProxy from team_one.agents.coder import Coder, Executor from team_one.agents.orchestrator import RoundRobinOrchestrator from team_one.agents.user_proxy import UserProxy @@ -15,17 +16,20 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register agents. - coder = await runtime.register_and_get_proxy( + await runtime.register( "Coder", lambda: Coder(model_client=create_completion_client_from_env()), ) + coder = AgentProxy(AgentId("Coder", "default"), runtime) - executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code")) + await runtime.register("Executor", lambda: Executor("A agent for executing code")) + executor = AgentProxy(AgentId("Executor", "default"), runtime) - user_proxy = await runtime.register_and_get_proxy( + await runtime.register( "UserProxy", - lambda: UserProxy(), + lambda: UserProxy(description="The current user interacting with you."), ) + user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime) await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy])) diff --git a/python/teams/team-one/examples/example_file_surfer.py b/python/teams/team-one/examples/example_file_surfer.py index c5d00c8cc..0318a5654 100644 --- a/python/teams/team-one/examples/example_file_surfer.py +++ b/python/teams/team-one/examples/example_file_surfer.py @@ -3,6 +3,7 @@ import logging from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME +from agnext.core import AgentId, AgentProxy from team_one.agents.file_surfer import FileSurfer from team_one.agents.orchestrator import RoundRobinOrchestrator from team_one.agents.user_proxy import UserProxy @@ -18,14 +19,17 @@ async def main() -> None: client = create_completion_client_from_env() # Register agents. - file_surfer = await runtime.register_and_get_proxy( + await runtime.register( "file_surfer", lambda: FileSurfer(model_client=client), ) - user_proxy = await runtime.register_and_get_proxy( + file_surfer = AgentProxy(AgentId("file_surfer", "default"), runtime) + + await runtime.register( "UserProxy", lambda: UserProxy(), ) + user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime) await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy])) diff --git a/python/teams/team-one/examples/example_reflexagents.py b/python/teams/team-one/examples/example_reflexagents.py index 772c6ee84..88c05c942 100644 --- a/python/teams/team-one/examples/example_reflexagents.py +++ b/python/teams/team-one/examples/example_reflexagents.py @@ -4,6 +4,7 @@ import logging from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME from agnext.components.models import UserMessage +from agnext.core import AgentId, AgentProxy, TopicId from team_one.agents.orchestrator import RoundRobinOrchestrator from team_one.agents.reflex_agents import ReflexAgent from team_one.messages import BroadcastMessage @@ -13,14 +14,19 @@ from team_one.utils import LogHandler async def main() -> None: runtime = SingleThreadedAgentRuntime() - fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent")) - fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent")) - fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent")) - await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3])) + await runtime.register("fake_agent_1", lambda: ReflexAgent("First reflect agent")) + fake1 = AgentProxy(AgentId("fake_agent_1", "default"), runtime) + await runtime.register("fake_agent_2", lambda: ReflexAgent("Second reflect agent")) + fake2 = AgentProxy(AgentId("fake_agent_2", "default"), runtime) + + await runtime.register("fake_agent_3", lambda: ReflexAgent("Third reflect agent")) + fake3 = AgentProxy(AgentId("fake_agent_3", "default"), runtime) + + await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3])) task_message = UserMessage(content="Test Message", source="User") run_context = runtime.start() - await runtime.publish_message(BroadcastMessage(task_message), namespace="default") + await runtime.publish_message(BroadcastMessage(task_message), topic_id=TopicId("default", "default")) await run_context.stop_when_idle() diff --git a/python/teams/team-one/examples/example_userproxy.py b/python/teams/team-one/examples/example_userproxy.py index 9586f6b1a..dcce89e44 100644 --- a/python/teams/team-one/examples/example_userproxy.py +++ b/python/teams/team-one/examples/example_userproxy.py @@ -4,6 +4,7 @@ import logging # from typing import Any, Dict, List, Tuple, Union from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME +from agnext.core import AgentId, AgentProxy from team_one.agents.coder import Coder from team_one.agents.orchestrator import RoundRobinOrchestrator from team_one.agents.user_proxy import UserProxy @@ -19,14 +20,17 @@ async def main() -> None: client = create_completion_client_from_env() # Register agents. - coder = await runtime.register_and_get_proxy( + await runtime.register( "Coder", lambda: Coder(model_client=client), ) - user_proxy = await runtime.register_and_get_proxy( + coder = AgentProxy(AgentId("Coder", "default"), runtime) + + await runtime.register( "UserProxy", lambda: UserProxy(), ) + user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime) await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy])) diff --git a/python/teams/team-one/examples/example_websurfer.py b/python/teams/team-one/examples/example_websurfer.py index 26c154b95..c0bcebefc 100644 --- a/python/teams/team-one/examples/example_websurfer.py +++ b/python/teams/team-one/examples/example_websurfer.py @@ -4,6 +4,7 @@ import os from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME +from agnext.core import AgentId, AgentProxy from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer from team_one.agents.orchestrator import RoundRobinOrchestrator from team_one.agents.user_proxy import UserProxy @@ -21,15 +22,17 @@ async def main() -> None: client = create_completion_client_from_env() # Register agents. - web_surfer = await runtime.register_and_get_proxy( + await runtime.register( "WebSurfer", lambda: MultimodalWebSurfer(), ) + web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime) - user_proxy = await runtime.register_and_get_proxy( + await runtime.register( "UserProxy", lambda: UserProxy(), ) + user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime) await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy])) diff --git a/python/teams/team-one/src/team_one/agents/base_worker.py b/python/teams/team-one/src/team_one/agents/base_worker.py index a3e089c09..6f8b68884 100644 --- a/python/teams/team-one/src/team_one/agents/base_worker.py +++ b/python/teams/team-one/src/team_one/agents/base_worker.py @@ -5,7 +5,7 @@ from agnext.components.models import ( LLMMessage, UserMessage, ) -from agnext.core import CancellationToken, MessageContext +from agnext.core import CancellationToken, MessageContext, TopicId from team_one.messages import ( BroadcastMessage, @@ -45,7 +45,8 @@ class BaseWorker(TeamOneBaseAgent): self._chat_history.append(assistant_message) user_message = UserMessage(content=response, source=self.metadata["type"]) - await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt)) + topic_id = TopicId("default", self.id.key) + await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id) async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]: """Returns (request_halt, response_message)""" diff --git a/python/teams/team-one/src/team_one/agents/orchestrator.py b/python/teams/team-one/src/team_one/agents/orchestrator.py index fe5f9a13c..1d338b875 100644 --- a/python/teams/team-one/src/team_one/agents/orchestrator.py +++ b/python/teams/team-one/src/team_one/agents/orchestrator.py @@ -2,7 +2,7 @@ import json from typing import Any, Dict, List, Optional from agnext.components.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage -from agnext.core import AgentProxy +from agnext.core import AgentProxy, TopicId from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage from .base_orchestrator import BaseOrchestrator, logger @@ -248,8 +248,10 @@ class LedgerOrchestrator(BaseOrchestrator): synthesized_prompt = self._get_synthesize_prompt( self._task, self._team_description, self._facts, self._plan ) + topic_id = TopicId("default", self.id.key) await self.publish_message( - BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])) + BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])), + topic_id=topic_id, ) logger.info( @@ -319,14 +321,17 @@ class LedgerOrchestrator(BaseOrchestrator): # Reset everyone, then rebroadcast the new plan self._chat_history = [self._chat_history[0]] - await self.publish_message(ResetMessage()) + topic_id = TopicId("default", self.id.key) + await self.publish_message(ResetMessage(), topic_id=topic_id) # Send everyone the NEW plan synthesized_prompt = self._get_synthesize_prompt( self._task, self._team_description, self._facts, self._plan ) + topic_id = TopicId("default", self.id.key) await self.publish_message( - BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])) + BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])), + topic_id=topic_id, ) logger.info( @@ -351,8 +356,10 @@ class LedgerOrchestrator(BaseOrchestrator): assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"]) logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction)) self._chat_history.append(assistant_message) # My copy + topic_id = TopicId("default", self.id.key) await self.publish_message( - BroadcastMessage(content=user_message, request_halt=False) + BroadcastMessage(content=user_message, request_halt=False), + topic_id=topic_id, ) # Send to everyone else return agent diff --git a/python/teams/team-one/src/team_one/agents/reflex_agents.py b/python/teams/team-one/src/team_one/agents/reflex_agents.py index cd0df896e..bdd4511db 100644 --- a/python/teams/team-one/src/team_one/agents/reflex_agents.py +++ b/python/teams/team-one/src/team_one/agents/reflex_agents.py @@ -1,6 +1,6 @@ from agnext.components import TypeRoutedAgent, message_handler from agnext.components.models import UserMessage -from agnext.core import MessageContext +from agnext.core import MessageContext, TopicId from ..messages import BroadcastMessage, RequestReplyMessage @@ -22,5 +22,6 @@ class ReflexAgent(TypeRoutedAgent): content=f"Hello, world from {name}!", source=name, ) + topic_id = TopicId("default", self.id.key) - await self.publish_message(BroadcastMessage(response_message)) + await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id) diff --git a/python/teams/team-one/tests/headless_web_surfer/test_web_surfer.py b/python/teams/team-one/tests/headless_web_surfer/test_web_surfer.py index b7d4551c8..3f87670f5 100644 --- a/python/teams/team-one/tests/headless_web_surfer/test_web_surfer.py +++ b/python/teams/team-one/tests/headless_web_surfer/test_web_surfer.py @@ -7,11 +7,14 @@ from math import ceil import asyncio import pytest + +from agnext.core import AgentId +from agnext.core import AgentProxy pytest_plugins = ('pytest_asyncio',) from json import dumps from team_one.utils import ( - ENVIRON_KEY_CHAT_COMPLETION_PROVIDER, + ENVIRON_KEY_CHAT_COMPLETION_PROVIDER, ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON, create_completion_client_from_env ) @@ -96,13 +99,14 @@ async def test_web_surfer() -> None: # Register agents. # Register agents. - web_surfer = await runtime.register_and_get_proxy( + await runtime.register( "WebSurfer", lambda: MultimodalWebSurfer(), ) + web_surfer = AgentId("WebSurfer", "default") run_context = runtime.start() - actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer) + actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer, MultimodalWebSurfer) await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium") # Test some basic navigations @@ -138,7 +142,7 @@ async def test_web_surfer() -> None: tool_resp = await make_browser_request(actual_surfer, TOOL_PAGE_DOWN) assert ( f"The viewport shows {viewport_percentage}% of the webpage, and is positioned at the bottom of the page" in tool_resp - ) + ) # Test Q&A and summarization -- we don't have a key so we expect it to fail #(but it means the code path is correct) with pytest.raises(AuthenticationError): @@ -160,15 +164,17 @@ async def test_web_surfer_oai() -> None: client = create_completion_client_from_env() # Register agents. - web_surfer = await runtime.register_and_get_proxy( + await runtime.register( "WebSurfer", lambda: MultimodalWebSurfer(), ) + web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime) - user_proxy = await runtime.register_and_get_proxy( + await runtime.register( "UserProxy", lambda: UserProxy(), ) + user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime) await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy])) run_context = runtime.start() @@ -220,10 +226,12 @@ async def test_web_surfer_bing() -> None: # Register agents. # Register agents. - web_surfer = await runtime.register_and_get_proxy( + await runtime.register( "WebSurfer", lambda: MultimodalWebSurfer(), ) + web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime) + run_context = runtime.start() actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer) await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium") @@ -235,7 +243,7 @@ async def test_web_surfer_bing() -> None: assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:url"] assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:title"] assert f"I typed '{BING_QUERY}' into the browser search bar." in tool_resp.replace("\\","") - + tool_resp = await make_browser_request(actual_surfer, TOOL_WEB_SEARCH, {"query": BING_QUERY + " Wikipedia"}) markdown = await actual_surfer._get_page_markdown() # type: ignore assert "https://en.wikipedia.org/wiki/" in markdown diff --git a/python/tests/test_cancellation.py b/python/tests/test_cancellation.py index a5ac020bf..75d36d1f2 100644 --- a/python/tests/test_cancellation.py +++ b/python/tests/test_cancellation.py @@ -6,16 +6,18 @@ from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler from agnext.core import AgentId, CancellationToken from agnext.core import MessageContext +from agnext.core import AgentInstantiationContext @dataclass -class MessageType: - ... +class MessageType: ... + # Note for future reader: # To do cancellation, only the token should be interacted with as a user # If you cancel a future, it may not work as you expect. + class LongRunningAgent(TypeRoutedAgent): def __init__(self) -> None: super().__init__("A long running agent") @@ -34,6 +36,7 @@ class LongRunningAgent(TypeRoutedAgent): self.cancelled = True raise + class NestingLongRunningAgent(TypeRoutedAgent): def __init__(self, nested_agent: AgentId) -> None: super().__init__("A nesting long running agent") @@ -58,9 +61,10 @@ class NestingLongRunningAgent(TypeRoutedAgent): async def test_cancellation_with_token() -> None: runtime = SingleThreadedAgentRuntime() - long_running = await runtime.register_and_get("long_running", LongRunningAgent) + await runtime.register("long_running", LongRunningAgent) + agent_id = AgentId("long_running", key="default") token = CancellationToken() - response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token)) + response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token)) assert not response.done() while len(runtime.unprocessed_messages) == 0: @@ -74,21 +78,25 @@ async def test_cancellation_with_token() -> None: await response assert response.done() - long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) + long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled - @pytest.mark.asyncio async def test_nested_cancellation_only_outer_called() -> None: runtime = SingleThreadedAgentRuntime() - long_running = await runtime.register_and_get("long_running", LongRunningAgent) - nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) + await runtime.register("long_running", LongRunningAgent) + await runtime.register( + "nested", + lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)), + ) + long_running_id = AgentId("long_running", key="default") + nested_id = AgentId("nested", key="default") token = CancellationToken() - response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token)) + response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token)) assert not response.done() while len(runtime.unprocessed_messages) == 0: @@ -101,22 +109,29 @@ async def test_nested_cancellation_only_outer_called() -> None: await response assert response.done() - nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent) + nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent) assert nested_agent.called assert nested_agent.cancelled - long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent) assert long_running_agent.called is False assert long_running_agent.cancelled is False + @pytest.mark.asyncio async def test_nested_cancellation_inner_called() -> None: runtime = SingleThreadedAgentRuntime() - long_running = await runtime.register_and_get("long_running", LongRunningAgent ) - nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) + await runtime.register("long_running", LongRunningAgent) + await runtime.register( + "nested", + lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)), + ) + + long_running_id = AgentId("long_running", key="default") + nested_id = AgentId("nested", key="default") token = CancellationToken() - response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token)) + response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token)) assert not response.done() while len(runtime.unprocessed_messages) == 0: @@ -131,9 +146,9 @@ async def test_nested_cancellation_inner_called() -> None: await response assert response.done() - nested_agent = await runtime.try_get_underlying_agent_instance(nested, type=NestingLongRunningAgent) + nested_agent = await runtime.try_get_underlying_agent_instance(nested_id, type=NestingLongRunningAgent) assert nested_agent.called assert nested_agent.cancelled - long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LongRunningAgent) + long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled diff --git a/python/tests/test_closure_agent.py b/python/tests/test_closure_agent.py index ce2f93d13..256b59f55 100644 --- a/python/tests/test_closure_agent.py +++ b/python/tests/test_closure_agent.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import pytest from agnext.application import SingleThreadedAgentRuntime +from agnext.components._type_subscription import TypeSubscription from agnext.core import AgentRuntime, AgentId from agnext.components import ClosureAgent @@ -13,6 +14,7 @@ from agnext.components import ClosureAgent import asyncio from agnext.core import MessageContext +from agnext.core import TopicId @dataclass class Message: @@ -30,11 +32,15 @@ async def test_register_receives_publish() -> None: key = id.key await queue.put((key, message.content)) - await runtime.register("name", lambda: ClosureAgent("My agent", log_message)) + await runtime.register("name", lambda: ClosureAgent("my_agent", log_message)) + await runtime.add_subscription(TypeSubscription("default", "name")) + topic_id = TopicId("default", "default") run_context = runtime.start() - await runtime.publish_message(Message("first message"), namespace="default") - await runtime.publish_message(Message("second message"), namespace="default") - await runtime.publish_message(Message("third message"), namespace="default") + + await runtime.publish_message(Message("first message"), topic_id=topic_id) + await runtime.publish_message(Message("second message"), topic_id=topic_id) + await runtime.publish_message(Message("third message"), topic_id=topic_id) + await run_context.stop_when_idle() diff --git a/python/tests/test_intervention.py b/python/tests/test_intervention.py index 62942bbb1..c566578e2 100644 --- a/python/tests/test_intervention.py +++ b/python/tests/test_intervention.py @@ -19,7 +19,8 @@ async def test_intervention_count_messages() -> None: handler = DebugInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - loopback = await runtime.register_and_get("name", LoopbackAgent) + await runtime.register("name", LoopbackAgent) + loopback = AgentId("name", key="default") run_context = runtime.start() _response = await runtime.send_message(MessageType(), recipient=loopback) @@ -40,7 +41,8 @@ async def test_intervention_drop_send() -> None: handler = DropSendInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - loopback = await runtime.register_and_get("name", LoopbackAgent) + await runtime.register("name", LoopbackAgent) + loopback = AgentId("name", key="default") run_context = runtime.start() with pytest.raises(MessageDroppedException): @@ -62,7 +64,8 @@ async def test_intervention_drop_response() -> None: handler = DropResponseInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - loopback = await runtime.register_and_get("name", LoopbackAgent) + await runtime.register("name", LoopbackAgent) + loopback = AgentId("name", key="default") run_context = runtime.start() with pytest.raises(MessageDroppedException): @@ -84,15 +87,16 @@ async def test_intervention_raise_exception_on_send() -> None: handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - long_running = await runtime.register_and_get("name", LoopbackAgent) + await runtime.register("name", LoopbackAgent) + loopback = AgentId("name", key="default") run_context = runtime.start() with pytest.raises(InterventionException): - _response = await runtime.send_message(MessageType(), recipient=long_running) + _response = await runtime.send_message(MessageType(), recipient=loopback) await run_context.stop() - long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent) + long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert long_running_agent.num_calls == 0 @pytest.mark.asyncio @@ -108,12 +112,13 @@ async def test_intervention_raise_exception_on_respond() -> None: handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - long_running = await runtime.register_and_get("name", LoopbackAgent) + await runtime.register("name", LoopbackAgent) + loopback = AgentId("name", key="default") run_context = runtime.start() with pytest.raises(InterventionException): - _response = await runtime.send_message(MessageType(), recipient=long_running) + _response = await runtime.send_message(MessageType(), recipient=loopback) await run_context.stop() - long_running_agent = await runtime.try_get_underlying_agent_instance(long_running, type=LoopbackAgent) + long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert long_running_agent.num_calls == 1 diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index 7344847f1..ccdb25d3d 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -1,6 +1,8 @@ import pytest from agnext.application import SingleThreadedAgentRuntime +from agnext.components._type_subscription import TypeSubscription from agnext.core import AgentId, AgentInstantiationContext +from agnext.core import TopicId from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent @@ -15,13 +17,12 @@ async def test_agent_names_must_be_unique() -> None: assert agent.id == id return agent - agent1 = await runtime.register_and_get("name1", agent_factory) - assert agent1 == AgentId("name1", "default") + await runtime.register("name1", agent_factory) with pytest.raises(ValueError): - _agent1 = await runtime.register_and_get("name1", NoopAgent) + await runtime.register("name1", NoopAgent) - _agent1 = await runtime.register_and_get("name3", NoopAgent) + await runtime.register("name3", NoopAgent) @pytest.mark.asyncio @@ -30,16 +31,19 @@ async def test_register_receives_publish() -> None: await runtime.register("name", LoopbackAgent) run_context = runtime.start() - await runtime.publish_message(MessageType(), namespace="default") + await runtime.add_subscription(TypeSubscription("default", "name")) + agent_id = AgentId("name", key="default") + topic_id = TopicId("default", "default") + await runtime.publish_message(MessageType(), topic_id=topic_id) await run_context.stop_when_idle() # Agent in default namespace should have received the message - long_running_agent = await runtime.try_get_underlying_agent_instance(await runtime.get("name"), type=LoopbackAgent) + long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent) assert long_running_agent.num_calls == 1 # Agent in other namespace should not have received the message - other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(await runtime.get("name", namespace="other"), type=LoopbackAgent) + other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent) assert other_long_running_agent.num_calls == 0 @@ -56,17 +60,19 @@ async def test_register_receives_publish_cascade() -> None: # Register agents for i in range(num_agents): await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds)) + await runtime.add_subscription(TypeSubscription("default", f"name{i}")) run_context = runtime.start() # Publish messages + topic_id = TopicId("default", "default") for _ in range(num_initial_messages): - await runtime.publish_message(CascadingMessageType(round=1), namespace="default") + await runtime.publish_message(CascadingMessageType(round=1), topic_id) # Process until idle. await run_context.stop_when_idle() # Check that each agent received the correct number of messages. for i in range(num_agents): - agent = await runtime.try_get_underlying_agent_instance(await runtime.get(f"name{i}"), CascadingAgent) + agent = await runtime.try_get_underlying_agent_instance(AgentId(f"name{i}", "default"), CascadingAgent) assert agent.num_calls == total_num_calls_expected diff --git a/python/tests/test_serialization.py b/python/tests/test_serialization.py index 7ffd96388..dacd272bc 100644 --- a/python/tests/test_serialization.py +++ b/python/tests/test_serialization.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import pytest -from agnext.core._serialization import Serialization +from agnext.core import Serialization class PydanticMessage(BaseModel): message: str diff --git a/python/tests/test_state.py b/python/tests/test_state.py index 8bc562d1c..33fc8d67a 100644 --- a/python/tests/test_state.py +++ b/python/tests/test_state.py @@ -1,19 +1,16 @@ -from typing import Any, Mapping, Sequence +from typing import Any, Mapping import pytest from agnext.application import SingleThreadedAgentRuntime from agnext.core import BaseAgent, MessageContext +from agnext.core import AgentId class StatefulAgent(BaseAgent): def __init__(self) -> None: - super().__init__("A stateful agent", []) + super().__init__("A stateful agent") self.state = 0 - @property - def subscriptions(self) -> Sequence[type]: - return [] - async def on_message(self, message: Any, ctx: MessageContext) -> None: raise NotImplementedError @@ -28,7 +25,8 @@ class StatefulAgent(BaseAgent): async def test_agent_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - agent1_id = await runtime.register_and_get("name1", StatefulAgent) + await runtime.register("name1", StatefulAgent) + agent1_id = AgentId("name1", key="default") agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent) assert agent1.state == 0 agent1.state = 1 @@ -46,7 +44,8 @@ async def test_agent_can_save_state() -> None: async def test_runtime_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - agent1_id = await runtime.register_and_get("name1", StatefulAgent) + await runtime.register("name1", StatefulAgent) + agent1_id = AgentId("name1", key="default") agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent) assert agent1.state == 0 agent1.state = 1 @@ -55,7 +54,8 @@ async def test_runtime_can_save_state() -> None: runtime_state = await runtime.save_state() runtime2 = SingleThreadedAgentRuntime() - agent2_id = await runtime2.register_and_get("name1", StatefulAgent) + await runtime2.register("name1", StatefulAgent) + agent2_id = AgentId("name1", key="default") agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent) await runtime2.load_state(runtime_state) diff --git a/python/tests/test_tool_agent.py b/python/tests/test_tool_agent.py index f2dd5423f..acc66e1fa 100644 --- a/python/tests/test_tool_agent.py +++ b/python/tests/test_tool_agent.py @@ -13,6 +13,7 @@ from agnext.components.tool_agent import ( ) from agnext.components.tools import FunctionTool from agnext.core import CancellationToken +from agnext.core import AgentId def _pass_function(input: str) -> str: @@ -31,7 +32,7 @@ async def _async_sleep_function(input: str) -> str: @pytest.mark.asyncio async def test_tool_agent() -> None: runtime = SingleThreadedAgentRuntime() - agent = await runtime.register_and_get( + await runtime.register( "tool_agent", lambda: ToolAgent( description="Tool agent", @@ -42,6 +43,7 @@ async def test_tool_agent() -> None: ], ), ) + agent = AgentId("tool_agent", "default") run = runtime.start() # Test pass function diff --git a/python/tests/test_utils/__init__.py b/python/tests/test_utils/__init__.py index 9ce86897f..5a056a891 100644 --- a/python/tests/test_utils/__init__.py +++ b/python/tests/test_utils/__init__.py @@ -38,11 +38,12 @@ class CascadingAgent(TypeRoutedAgent): self.num_calls += 1 if message.round == self.max_rounds: return - await self.publish_message(CascadingMessageType(round=message.round + 1)) + assert ctx.topic_id is not None + await self.publish_message(CascadingMessageType(round=message.round + 1), topic_id=ctx.topic_id) class NoopAgent(BaseAgent): def __init__(self) -> None: - super().__init__("A no op agent", []) + super().__init__("A no op agent") async def on_message(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError