From 853b00b0f058aa0afdefaa5d3855c9b68c319929 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 16 Aug 2024 23:14:09 -0400 Subject: [PATCH] Add message context to message handler (#367) Co-authored-by: Eric Zhu --- .../docs/src/cookbook/langgraph-agent.ipynb | 4 +- .../docs/src/cookbook/llamaindex-agent.ipynb | 4 +- .../src/cookbook/openai-assistant-agent.ipynb | 40 +++++++++---------- .../agent-and-agent-runtime.ipynb | 4 +- .../message-and-communication.ipynb | 18 ++++----- .../src/getting-started/model-clients.ipynb | 6 +-- .../multi-agent-design-patterns.ipynb | 14 +++---- python/docs/src/getting-started/tools.ipynb | 4 +- python/samples/byoa/langgraph_agent.py | 4 +- python/samples/byoa/llamaindex_agent.py | 4 +- .../common/agents/_chat_completion_agent.py | 26 ++++++------ .../common/agents/_image_generation_agent.py | 11 ++--- .../samples/common/agents/_oai_assistant.py | 19 +++++---- python/samples/common/agents/_user_proxy.py | 4 +- .../common/patterns/_group_chat_manager.py | 8 ++-- .../common/patterns/_orchestrator_chat.py | 4 +- python/samples/core/inner_outer_direct.py | 6 +-- python/samples/core/one_agent_direct.py | 4 +- python/samples/core/two_agents_pub_sub.py | 4 +- python/samples/demos/chat_room.py | 7 ++-- python/samples/marketing-agents/auditor.py | 4 +- .../marketing-agents/graphic_designer.py | 4 +- python/samples/marketing-agents/test_usage.py | 6 +-- python/samples/patterns/coder_executor.py | 25 ++++++------ python/samples/patterns/coder_reviewer.py | 8 ++-- python/samples/patterns/group_chat.py | 9 +++-- python/samples/patterns/mixture_of_agents.py | 8 ++-- python/samples/patterns/multi_agent_debate.py | 12 +++--- python/samples/tool-use/coding_direct.py | 7 ++-- python/samples/tool-use/coding_pub_sub.py | 10 ++--- .../_single_threaded_agent_runtime.py | 20 ++++++++-- .../src/agnext/components/_closure_agent.py | 11 ++--- .../agnext/components/_type_routed_agent.py | 28 ++++++------- .../components/tool_agent/_tool_agent.py | 8 ++-- python/src/agnext/core/__init__.py | 2 + python/src/agnext/core/_agent.py | 6 +-- python/src/agnext/core/_base_agent.py | 3 +- python/src/agnext/core/_message_context.py | 13 ++++++ python/src/agnext/worker/worker_runtime.py | 12 +++++- .../src/team_one/agents/base_agent.py | 36 ++++++++--------- .../src/team_one/agents/base_orchestrator.py | 8 ++-- .../src/team_one/agents/base_worker.py | 12 +++--- .../multimodal_web_surfer.py | 7 +--- .../src/team_one/agents/reflex_agents.py | 10 ++--- python/tests/test_cancellation.py | 9 +++-- python/tests/test_closure_agent.py | 6 ++- python/tests/test_state.py | 4 +- python/tests/test_types.py | 8 ++-- python/tests/test_utils/__init__.py | 11 ++--- 49 files changed, 267 insertions(+), 235 deletions(-) create mode 100644 python/src/agnext/core/_message_context.py diff --git a/python/docs/src/cookbook/langgraph-agent.ipynb b/python/docs/src/cookbook/langgraph-agent.ipynb index 84d9bc83a..27280a57b 100644 --- a/python/docs/src/cookbook/langgraph-agent.ipynb +++ b/python/docs/src/cookbook/langgraph-agent.ipynb @@ -45,7 +45,7 @@ "\n", "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import CancellationToken\n", + "from agnext.core import MessageContext\n", "from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n", "from langchain_core.messages import HumanMessage, SystemMessage\n", "from langchain_core.tools import tool # pyright: ignore\n", @@ -162,7 +162,7 @@ " self._app = self._workflow.compile()\n", "\n", " @message_handler\n", - " async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n", + " async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n", " # Use the Runnable\n", " final_state = await self._app.ainvoke(\n", " {\n", diff --git a/python/docs/src/cookbook/llamaindex-agent.ipynb b/python/docs/src/cookbook/llamaindex-agent.ipynb index 5c2665221..d14dc261b 100644 --- a/python/docs/src/cookbook/llamaindex-agent.ipynb +++ b/python/docs/src/cookbook/llamaindex-agent.ipynb @@ -44,7 +44,7 @@ "\n", "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import CancellationToken\n", + "from agnext.core import MessageContext\n", "from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n", "from llama_index.core import Settings\n", "from llama_index.core.agent import ReActAgent\n", @@ -110,7 +110,7 @@ " self._memory = memory\n", "\n", " @message_handler\n", - " async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n", + " async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n", " # retriever history messages from memory!\n", " history_messages: List[ChatMessage] = []\n", "\n", diff --git a/python/docs/src/cookbook/openai-assistant-agent.ipynb b/python/docs/src/cookbook/openai-assistant-agent.ipynb index ea10868fb..feadcec11 100644 --- a/python/docs/src/cookbook/openai-assistant-agent.ipynb +++ b/python/docs/src/cookbook/openai-assistant-agent.ipynb @@ -105,7 +105,7 @@ "\n", "import aiofiles\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import CancellationToken\n", + "from agnext.core import MessageContext\n", "from openai import AsyncAssistantEventHandler, AsyncClient\n", "from openai.types.beta.thread import ToolResources, ToolResourcesFileSearch\n", "\n", @@ -140,10 +140,10 @@ " self._assistant_event_handler_factory = assistant_event_handler_factory\n", "\n", " @message_handler\n", - " async def handle_message(self, message: TextMessage, cancellation_token: CancellationToken) -> TextMessage:\n", + " async def handle_message(self, message: TextMessage, ctx: MessageContext) -> TextMessage:\n", " \"\"\"Handle a message. This method adds the message to the thread and publishes a response.\"\"\"\n", " # Save the message to the thread.\n", - " await cancellation_token.link_future(\n", + " await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(\n", " self._client.beta.threads.messages.create(\n", " thread_id=self._thread_id,\n", @@ -159,10 +159,10 @@ " assistant_id=self._assistant_id,\n", " event_handler=self._assistant_event_handler_factory(),\n", " ) as stream:\n", - " await cancellation_token.link_future(asyncio.ensure_future(stream.until_done()))\n", + " await ctx.cancellation_token.link_future(asyncio.ensure_future(stream.until_done()))\n", "\n", " # Get the last message.\n", - " messages = await cancellation_token.link_future(\n", + " messages = await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, order=\"desc\", limit=1))\n", " )\n", " last_message_content = messages.data[0].content\n", @@ -175,17 +175,17 @@ " return TextMessage(content=text_content[0].text.value, source=self.metadata[\"type\"])\n", "\n", " @message_handler()\n", - " async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:\n", + " async def on_reset(self, message: Reset, ctx: MessageContext) -> None:\n", " \"\"\"Handle a reset message. This method deletes all messages in the thread.\"\"\"\n", " # Get all messages in this thread.\n", " all_msgs: List[str] = []\n", " while True:\n", " if not all_msgs:\n", - " msgs = await cancellation_token.link_future(\n", + " msgs = await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id))\n", " )\n", " else:\n", - " msgs = await cancellation_token.link_future(\n", + " msgs = await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(self._client.beta.threads.messages.list(self._thread_id, after=all_msgs[-1]))\n", " )\n", " for msg in msgs.data:\n", @@ -194,7 +194,7 @@ " break\n", " # Delete all the messages.\n", " for msg_id in all_msgs:\n", - " status = await cancellation_token.link_future(\n", + " status = await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(\n", " self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)\n", " )\n", @@ -202,20 +202,18 @@ " assert status.deleted is True\n", "\n", " @message_handler()\n", - " async def on_upload_for_code_interpreter(\n", - " self, message: UploadForCodeInterpreter, cancellation_token: CancellationToken\n", - " ) -> None:\n", + " async def on_upload_for_code_interpreter(self, message: UploadForCodeInterpreter, ctx: MessageContext) -> None:\n", " \"\"\"Handle an upload for code interpreter. This method uploads a file and updates the thread with the file.\"\"\"\n", " # Get the file content.\n", " async with aiofiles.open(message.file_path, mode=\"rb\") as f:\n", - " file_content = await cancellation_token.link_future(asyncio.ensure_future(f.read()))\n", + " file_content = await ctx.cancellation_token.link_future(asyncio.ensure_future(f.read()))\n", " file_name = os.path.basename(message.file_path)\n", " # Upload the file.\n", - " file = await cancellation_token.link_future(\n", + " file = await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(self._client.files.create(file=(file_name, file_content), purpose=\"assistants\"))\n", " )\n", " # Get existing file ids from tool resources.\n", - " thread = await cancellation_token.link_future(\n", + " thread = await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(self._client.beta.threads.retrieve(thread_id=self._thread_id))\n", " )\n", " tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources()\n", @@ -225,7 +223,7 @@ " else:\n", " file_ids = [file.id]\n", " # Update thread with new file.\n", - " await cancellation_token.link_future(\n", + " await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(\n", " self._client.beta.threads.update(\n", " thread_id=self._thread_id,\n", @@ -237,16 +235,14 @@ " )\n", "\n", " @message_handler()\n", - " async def on_upload_for_file_search(\n", - " self, message: UploadForFileSearch, cancellation_token: CancellationToken\n", - " ) -> None:\n", + " async def on_upload_for_file_search(self, message: UploadForFileSearch, ctx: MessageContext) -> None:\n", " \"\"\"Handle an upload for file search. This method uploads a file and updates the vector store.\"\"\"\n", " # Get the file content.\n", " async with aiofiles.open(message.file_path, mode=\"rb\") as file:\n", - " file_content = await cancellation_token.link_future(asyncio.ensure_future(file.read()))\n", + " file_content = await ctx.cancellation_token.link_future(asyncio.ensure_future(file.read()))\n", " file_name = os.path.basename(message.file_path)\n", " # Upload the file.\n", - " await cancellation_token.link_future(\n", + " await ctx.cancellation_token.link_future(\n", " asyncio.ensure_future(\n", " self._client.beta.vector_stores.file_batches.upload_and_poll(\n", " vector_store_id=message.vector_store_id,\n", @@ -837,7 +833,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/python/docs/src/getting-started/agent-and-agent-runtime.ipynb b/python/docs/src/getting-started/agent-and-agent-runtime.ipynb index 009b7c362..4717d4804 100644 --- a/python/docs/src/getting-started/agent-and-agent-runtime.ipynb +++ b/python/docs/src/getting-started/agent-and-agent-runtime.ipynb @@ -57,7 +57,7 @@ "source": [ "from dataclasses import dataclass\n", "\n", - "from agnext.core import BaseAgent, CancellationToken\n", + "from agnext.core import BaseAgent, MessageContext\n", "\n", "\n", "@dataclass\n", @@ -69,7 +69,7 @@ " def __init__(self) -> None:\n", " super().__init__(\"MyAgent\", subscriptions=[\"MyMessage\"])\n", "\n", - " async def on_message(self, message: MyMessage, cancellation_token: CancellationToken) -> None:\n", + " async def on_message(self, message: MyMessage, ctx: MessageContext) -> None:\n", " print(f\"Received message: {message.content}\")" ] }, diff --git a/python/docs/src/getting-started/message-and-communication.ipynb b/python/docs/src/getting-started/message-and-communication.ipynb index b0e1838d9..a9ae906cd 100644 --- a/python/docs/src/getting-started/message-and-communication.ipynb +++ b/python/docs/src/getting-started/message-and-communication.ipynb @@ -83,16 +83,16 @@ "source": [ "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import CancellationToken\n", + "from agnext.core import MessageContext\n", "\n", "\n", "class MyAgent(TypeRoutedAgent):\n", " @message_handler\n", - " async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:\n", + " async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None:\n", " print(f\"Hello, {message.source}, you said {message.content}!\")\n", "\n", " @message_handler\n", - " async def on_image_message(self, message: ImageMessage, cancellation_token: CancellationToken) -> None:\n", + " async def on_image_message(self, message: ImageMessage, ctx: MessageContext) -> None:\n", " print(f\"Hello, {message.source}, you sent me {message.url}!\")" ] }, @@ -185,7 +185,7 @@ "\n", "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import AgentId, CancellationToken\n", + "from agnext.core import AgentId, MessageContext\n", "\n", "\n", "@dataclass\n", @@ -195,7 +195,7 @@ "\n", "class InnerAgent(TypeRoutedAgent):\n", " @message_handler\n", - " async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n", + " async def on_my_message(self, message: Message, ctx: MessageContext) -> Message:\n", " return Message(content=f\"Hello from inner, {message.content}\")\n", "\n", "\n", @@ -205,7 +205,7 @@ " self.inner_agent_id = inner_agent_id\n", "\n", " @message_handler\n", - " async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> None:\n", + " async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n", " print(f\"Received message: {message.content}\")\n", " # Send a direct message to the inner agent and receves a response.\n", " response = await self.send_message(Message(f\"Hello from outer, {message.content}\"), self.inner_agent_id)\n", @@ -294,19 +294,19 @@ "source": [ "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", - "from agnext.core import CancellationToken\n", + "from agnext.core import MessageContext\n", "\n", "\n", "class BroadcastingAgent(TypeRoutedAgent):\n", " @message_handler\n", - " async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> None:\n", + " async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n", " # Publish a message to all agents in the same namespace.\n", " await self.publish_message(Message(f\"Publishing a message: {message.content}!\"))\n", "\n", "\n", "class ReceivingAgent(TypeRoutedAgent):\n", " @message_handler\n", - " async def on_my_message(self, message: Message, cancellation_token: CancellationToken) -> None:\n", + " async def on_my_message(self, message: Message, ctx: MessageContext) -> None:\n", " print(f\"Received a message: {message.content}\")" ] }, diff --git a/python/docs/src/getting-started/model-clients.ipynb b/python/docs/src/getting-started/model-clients.ipynb index 7750f1d85..048821044 100644 --- a/python/docs/src/getting-started/model-clients.ipynb +++ b/python/docs/src/getting-started/model-clients.ipynb @@ -229,7 +229,7 @@ "from agnext.application import SingleThreadedAgentRuntime\n", "from agnext.components import TypeRoutedAgent, message_handler\n", "from agnext.components.models import ChatCompletionClient, OpenAIChatCompletionClient, SystemMessage, UserMessage\n", - "from agnext.core import CancellationToken\n", + "from agnext.core import MessageContext\n", "\n", "\n", "@dataclass\n", @@ -244,11 +244,11 @@ " self._model_client = model_client\n", "\n", " @message_handler\n", - " async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n", + " async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n", " # Prepare input to the chat completion model.\n", " user_message = UserMessage(content=message.content, source=\"user\")\n", " response = await self._model_client.create(\n", - " self._system_messages + [user_message], cancellation_token=cancellation_token\n", + " self._system_messages + [user_message], cancellation_token=ctx.cancellation_token\n", " )\n", " # Return with the model's response.\n", " assert isinstance(response.content, str)\n", diff --git a/python/docs/src/getting-started/multi-agent-design-patterns.ipynb b/python/docs/src/getting-started/multi-agent-design-patterns.ipynb index 44cd54c2c..8cd4a39eb 100644 --- a/python/docs/src/getting-started/multi-agent-design-patterns.ipynb +++ b/python/docs/src/getting-started/multi-agent-design-patterns.ipynb @@ -131,7 +131,7 @@ " SystemMessage,\n", " UserMessage,\n", ")\n", - "from agnext.core import CancellationToken" + "from agnext.core import MessageContext" ] }, { @@ -177,14 +177,14 @@ " self._session_memory: Dict[str, List[CodeWritingTask | CodeReviewTask | CodeReviewResult]] = {}\n", "\n", " @message_handler\n", - " async def handle_code_writing_task(self, message: CodeWritingTask, cancellation_token: CancellationToken) -> None:\n", + " async def handle_code_writing_task(self, message: CodeWritingTask, ctx: MessageContext) -> None:\n", " # Store the messages in a temporary memory for this request only.\n", " session_id = str(uuid.uuid4())\n", " self._session_memory.setdefault(session_id, []).append(message)\n", " # Generate a response using the chat completion API.\n", " response = await self._model_client.create(\n", " self._system_messages + [UserMessage(content=message.task, source=self.metadata[\"type\"])],\n", - " cancellation_token=cancellation_token,\n", + " cancellation_token=ctx.cancellation_token,\n", " )\n", " assert isinstance(response.content, str)\n", " # Extract the code block from the response.\n", @@ -204,7 +204,7 @@ " await self.publish_message(code_review_task)\n", "\n", " @message_handler\n", - " async def handle_code_review_result(self, message: CodeReviewResult, cancellation_token: CancellationToken) -> None:\n", + " async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None:\n", " # Store the review result in the session memory.\n", " self._session_memory[message.session_id].append(message)\n", " # Obtain the request from previous messages.\n", @@ -243,7 +243,7 @@ " else:\n", " raise ValueError(f\"Unexpected message type: {m}\")\n", " # Generate a revision using the chat completion API.\n", - " response = await self._model_client.create(messages, cancellation_token=cancellation_token)\n", + " response = await self._model_client.create(messages, cancellation_token=ctx.cancellation_token)\n", " assert isinstance(response.content, str)\n", " # Extract the code block from the response.\n", " code_block = self._extract_code_block(response.content)\n", @@ -315,7 +315,7 @@ " self._model_client = model_client\n", "\n", " @message_handler\n", - " async def handle_code_review_task(self, message: CodeReviewTask, cancellation_token: CancellationToken) -> None:\n", + " async def handle_code_review_task(self, message: CodeReviewTask, ctx: MessageContext) -> None:\n", " # Format the prompt for the code review.\n", " # Gather the previous feedback if available.\n", " previous_feedback = \"\"\n", @@ -342,7 +342,7 @@ " # Generate a response using the chat completion API.\n", " response = await self._model_client.create(\n", " self._system_messages + [UserMessage(content=prompt, source=self.metadata[\"type\"])],\n", - " cancellation_token=cancellation_token,\n", + " cancellation_token=ctx.cancellation_token,\n", " json_output=True,\n", " )\n", " assert isinstance(response.content, str)\n", diff --git a/python/docs/src/getting-started/tools.ipynb b/python/docs/src/getting-started/tools.ipynb index 65bfb192b..7a3ddc2bb 100644 --- a/python/docs/src/getting-started/tools.ipynb +++ b/python/docs/src/getting-started/tools.ipynb @@ -148,7 +148,7 @@ ")\n", "from agnext.components.tool_agent import ToolAgent, ToolException\n", "from agnext.components.tools import FunctionTool, Tool, ToolSchema\n", - "from agnext.core import AgentId, CancellationToken\n", + "from agnext.core import AgentId, MessageContext\n", "\n", "\n", "@dataclass\n", @@ -165,7 +165,7 @@ " self._tool_agent = tool_agent\n", "\n", " @message_handler\n", - " async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message:\n", + " async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n", " # Create a session of messages.\n", " session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n", " # Get a response from the model.\n", diff --git a/python/samples/byoa/langgraph_agent.py b/python/samples/byoa/langgraph_agent.py index 490d37148..41c3ee43c 100644 --- a/python/samples/byoa/langgraph_agent.py +++ b/python/samples/byoa/langgraph_agent.py @@ -10,7 +10,7 @@ from typing import Any, Callable, List, Literal from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import MessageContext from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool # pyright: ignore from langchain_openai import ChatOpenAI @@ -89,7 +89,7 @@ class LangGraphToolUseAgent(TypeRoutedAgent): self._app = self._workflow.compile() @message_handler - async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message: + async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message: # Use the Runnable final_state = await self._app.ainvoke( { diff --git a/python/samples/byoa/llamaindex_agent.py b/python/samples/byoa/llamaindex_agent.py index f3918440b..72d06e7d5 100644 --- a/python/samples/byoa/llamaindex_agent.py +++ b/python/samples/byoa/llamaindex_agent.py @@ -9,7 +9,7 @@ from typing import List, Optional from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import MessageContext from llama_index.core import Settings from llama_index.core.agent import ReActAgent from llama_index.core.agent.runner.base import AgentRunner @@ -46,7 +46,7 @@ class LlamaIndexAgent(TypeRoutedAgent): self._memory = memory @message_handler - async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message: + async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message: # retriever history messages from memory! history_messages: List[ChatMessage] = [] diff --git a/python/samples/common/agents/_chat_completion_agent.py b/python/samples/common/agents/_chat_completion_agent.py index a3a921440..bd48aae5b 100644 --- a/python/samples/common/agents/_chat_completion_agent.py +++ b/python/samples/common/agents/_chat_completion_agent.py @@ -15,7 +15,7 @@ from agnext.components.models import ( SystemMessage, ) from agnext.components.tools import Tool -from agnext.core import AgentId, CancellationToken +from agnext.core import AgentId, CancellationToken, MessageContext from ..types import ( FunctionCallMessage, @@ -74,50 +74,48 @@ class ChatCompletionAgent(TypeRoutedAgent): self._tool_approver = tool_approver @message_handler() - async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: + async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None: """Handle a text message. This method adds the message to the memory and does not generate any message.""" # Add a user message. await self._memory.add_message(message) @message_handler() - async def on_multi_modal_message(self, message: MultiModalMessage, cancellation_token: CancellationToken) -> None: + async def on_multi_modal_message(self, message: MultiModalMessage, ctx: MessageContext) -> None: """Handle a multimodal message. This method adds the message to the memory and does not generate any message.""" # Add a user message. await self._memory.add_message(message) @message_handler() - async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: + async def on_reset(self, message: Reset, ctx: MessageContext) -> None: """Handle a reset message. This method clears the memory.""" # Reset the chat messages. await self._memory.clear() @message_handler() - async def on_respond_now( - self, message: RespondNow, cancellation_token: CancellationToken - ) -> TextMessage | FunctionCallMessage: + async def on_respond_now(self, message: RespondNow, ctx: MessageContext) -> TextMessage | FunctionCallMessage: """Handle a respond now message. This method generates a response and returns it to the sender.""" # Generate a response. - response = await self._generate_response(message.response_format, cancellation_token) + response = await self._generate_response(message.response_format, ctx) # Return the response. return response @message_handler() - async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: + async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: """Handle a publish now message. This method generates a response and publishes it.""" # Generate a response. - response = await self._generate_response(message.response_format, cancellation_token) + response = await self._generate_response(message.response_format, ctx) # Publish the response. await self.publish_message(response) @message_handler() async def on_tool_call_message( - self, message: FunctionCallMessage, cancellation_token: CancellationToken + self, message: FunctionCallMessage, ctx: MessageContext ) -> FunctionExecutionResultMessage: """Handle a tool call message. This method executes the tools and returns the results.""" @@ -147,7 +145,7 @@ class ChatCompletionAgent(TypeRoutedAgent): function_call.name, arguments, function_call.id, - cancellation_token=cancellation_token, + cancellation_token=ctx.cancellation_token, ) # Append the async result. execution_futures.append(future) @@ -170,7 +168,7 @@ class ChatCompletionAgent(TypeRoutedAgent): async def _generate_response( self, response_format: ResponseFormat, - cancellation_token: CancellationToken, + ctx: MessageContext, ) -> TextMessage | FunctionCallMessage: # Get a response from the model. hisorical_messages = await self._memory.get_messages() @@ -192,7 +190,7 @@ class ChatCompletionAgent(TypeRoutedAgent): response = await self.send_message( message=FunctionCallMessage(content=response.content, source=self.metadata["type"]), recipient=self.id, - cancellation_token=cancellation_token, + cancellation_token=ctx.cancellation_token, ) # Make an assistant message from the response. hisorical_messages = await self._memory.get_messages() diff --git a/python/samples/common/agents/_image_generation_agent.py b/python/samples/common/agents/_image_generation_agent.py index b9a08ade0..31650e1c7 100644 --- a/python/samples/common/agents/_image_generation_agent.py +++ b/python/samples/common/agents/_image_generation_agent.py @@ -7,7 +7,8 @@ from agnext.components import ( message_handler, ) from agnext.components.memory import ChatMemory -from agnext.core import CancellationToken +from agnext.core import MessageContext +from agnext.core._cancellation_token import CancellationToken from ..types import ( Message, @@ -42,21 +43,21 @@ class ImageGenerationAgent(TypeRoutedAgent): self._memory = memory @message_handler - async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: + async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None: """Handle a text message. This method adds the message to the memory.""" await self._memory.add_message(message) @message_handler - async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: + async def on_reset(self, message: Reset, ctx: MessageContext) -> None: await self._memory.clear() @message_handler - async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: + async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: """Handle a publish now message. This method generates an image using a DALL-E model with a prompt. The prompt is a concatenation of all TextMessages in the memory. The generated image is published as a MultiModalMessage.""" - response = await self._generate_response(cancellation_token) + response = await self._generate_response(ctx.cancellation_token) await self.publish_message(response) async def _generate_response(self, cancellation_token: CancellationToken) -> MultiModalMessage: diff --git a/python/samples/common/agents/_oai_assistant.py b/python/samples/common/agents/_oai_assistant.py index ae0833917..75a46d077 100644 --- a/python/samples/common/agents/_oai_assistant.py +++ b/python/samples/common/agents/_oai_assistant.py @@ -2,9 +2,12 @@ from typing import Any, Callable, List, Mapping import openai from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import ( + CancellationToken, + MessageContext, # type: ignore +) from openai import AsyncAssistantEventHandler -from openai.types import ResponseFormatJSONObject, ResponseFormatText # type: ignore +from openai.types import ResponseFormatJSONObject, ResponseFormatText from ..types import PublishNow, Reset, RespondNow, ResponseFormat, TextMessage @@ -39,7 +42,7 @@ class OpenAIAssistantAgent(TypeRoutedAgent): self._assistant_event_handler_factory = assistant_event_handler_factory @message_handler() - async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: + async def on_text_message(self, message: TextMessage, ctx: MessageContext) -> None: """Handle a text message. This method adds the message to the thread.""" # Save the message to the thread. _ = await self._client.beta.threads.messages.create( @@ -50,7 +53,7 @@ class OpenAIAssistantAgent(TypeRoutedAgent): ) @message_handler() - async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: + async def on_reset(self, message: Reset, ctx: MessageContext) -> None: """Handle a reset message. This method deletes all messages in the thread.""" # Get all messages in this thread. all_msgs: List[str] = [] @@ -69,14 +72,14 @@ class OpenAIAssistantAgent(TypeRoutedAgent): assert status.deleted is True @message_handler() - async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage: + async def on_respond_now(self, message: RespondNow, ctx: MessageContext) -> TextMessage: """Handle a respond now message. This method generates a response and returns it to the sender.""" - return await self._generate_response(message.response_format, cancellation_token) + return await self._generate_response(message.response_format, ctx.cancellation_token) @message_handler() - async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: + async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: """Handle a publish now message. This method generates a response and publishes it.""" - response = await self._generate_response(message.response_format, cancellation_token) + response = await self._generate_response(message.response_format, ctx.cancellation_token) await self.publish_message(response) async def _generate_response( diff --git a/python/samples/common/agents/_user_proxy.py b/python/samples/common/agents/_user_proxy.py index 048dac1b6..723490533 100644 --- a/python/samples/common/agents/_user_proxy.py +++ b/python/samples/common/agents/_user_proxy.py @@ -1,7 +1,7 @@ import asyncio from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import MessageContext from ..types import PublishNow, TextMessage @@ -20,7 +20,7 @@ class UserProxyAgent(TypeRoutedAgent): self._user_input_prompt = user_input_prompt @message_handler() - async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None: + async def on_publish_now(self, message: PublishNow, ctx: MessageContext) -> None: """Handle a publish now message. This method prompts the user for input, then publishes it.""" user_input = await self.get_user_input(self._user_input_prompt) await self.publish_message(TextMessage(content=user_input, source=self.metadata["type"])) diff --git a/python/samples/common/patterns/_group_chat_manager.py b/python/samples/common/patterns/_group_chat_manager.py index e119353be..ad0760ea7 100644 --- a/python/samples/common/patterns/_group_chat_manager.py +++ b/python/samples/common/patterns/_group_chat_manager.py @@ -4,7 +4,7 @@ from typing import Any, Callable, List, Mapping from agnext.components import TypeRoutedAgent, message_handler from agnext.components.memory import ChatMemory from agnext.components.models import ChatCompletionClient -from agnext.core import AgentId, AgentProxy, CancellationToken +from agnext.core import AgentId, AgentProxy, MessageContext from ..types import ( Message, @@ -76,14 +76,12 @@ class GroupChatManager(TypeRoutedAgent): self._on_message_received = on_message_received @message_handler() - async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: + async def on_reset(self, message: Reset, ctx: MessageContext) -> None: """Handle a reset message. This method clears the memory.""" await self._memory.clear() @message_handler() - async def on_new_message( - self, message: TextMessage | MultiModalMessage, cancellation_token: CancellationToken - ) -> None: + async def on_new_message(self, message: TextMessage | MultiModalMessage, ctx: MessageContext) -> None: """Handle a message. This method adds the message to the memory, selects the next speaker, and sends a message to the selected speaker to publish a response.""" # Call the custom on_message_received handler if provided. diff --git a/python/samples/common/patterns/_orchestrator_chat.py b/python/samples/common/patterns/_orchestrator_chat.py index a6d8b1772..accad556d 100644 --- a/python/samples/common/patterns/_orchestrator_chat.py +++ b/python/samples/common/patterns/_orchestrator_chat.py @@ -2,7 +2,7 @@ import json from typing import Any, Sequence, Tuple from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentId, AgentRuntime, CancellationToken +from agnext.core import AgentId, AgentRuntime, MessageContext from ..types import Reset, RespondNow, ResponseFormat, TextMessage @@ -37,7 +37,7 @@ class OrchestratorChat(TypeRoutedAgent): async def on_text_message( self, message: TextMessage, - cancellation_token: CancellationToken, + ctx: MessageContext, ) -> TextMessage: # A task is received. task = message.content diff --git a/python/samples/core/inner_outer_direct.py b/python/samples/core/inner_outer_direct.py index 5196df0e2..cc4fe3c53 100644 --- a/python/samples/core/inner_outer_direct.py +++ b/python/samples/core/inner_outer_direct.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentId, CancellationToken +from agnext.core import AgentId, MessageContext @dataclass @@ -26,7 +26,7 @@ class Inner(TypeRoutedAgent): super().__init__("The inner agent") @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: return MessageType(body=f"Inner: {message.body}", sender=self.metadata["type"]) @@ -36,7 +36,7 @@ class Outer(TypeRoutedAgent): self._inner = inner @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: inner_response = self.send_message(message, self._inner) inner_message = await inner_response assert isinstance(inner_message, MessageType) diff --git a/python/samples/core/one_agent_direct.py b/python/samples/core/one_agent_direct.py index b7b7b9a56..6319610ba 100644 --- a/python/samples/core/one_agent_direct.py +++ b/python/samples/core/one_agent_direct.py @@ -17,10 +17,10 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import CancellationToken sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -36,7 +36,7 @@ class ChatCompletionAgent(TypeRoutedAgent): self._model_client = model_client @message_handler - async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message: + async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message: user_message = UserMessage(content=message.content, source="User") response = await self._model_client.create(self._system_messages + [user_message]) assert isinstance(response.content, str) diff --git a/python/samples/core/two_agents_pub_sub.py b/python/samples/core/two_agents_pub_sub.py index b715b6361..8532781ad 100644 --- a/python/samples/core/two_agents_pub_sub.py +++ b/python/samples/core/two_agents_pub_sub.py @@ -25,10 +25,10 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import CancellationToken sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -57,7 +57,7 @@ class ChatCompletionAgent(TypeRoutedAgent): self._termination_word = termination_word @message_handler - async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None: + async def handle_message(self, message: Message, ctx: MessageContext) -> None: self._memory.append(message) if self._termination_word in message.content: return diff --git a/python/samples/demos/chat_room.py b/python/samples/demos/chat_room.py index 18830e54a..e6ea5dba4 100644 --- a/python/samples/demos/chat_room.py +++ b/python/samples/demos/chat_room.py @@ -9,11 +9,12 @@ from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler from agnext.components.memory import ChatMemory from agnext.components.models import ChatCompletionClient, SystemMessage -from agnext.core import AgentInstantiationContext, AgentRuntime, CancellationToken +from agnext.core import AgentInstantiationContext, AgentRuntime sys.path.append(os.path.abspath(os.path.dirname(__file__))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.memory import BufferedChatMemory from common.types import Message, TextMessage from common.utils import convert_messages_to_llm_messages, get_chat_completion_client_from_envs @@ -50,7 +51,7 @@ Use the following JSON format to provide your thought on the latest message and self._client = model_client @message_handler() - async def on_chat_room_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: + async def on_chat_room_message(self, message: TextMessage, ctx: MessageContext) -> None: # Save the message to memory as structured JSON. from_message = TextMessage( content=json.dumps({"sender": message.source, "content": message.content}), source=message.source @@ -82,7 +83,7 @@ class ChatRoomUserAgent(TextualUserAgent): """An agent that is used to receive messages from the runtime.""" @message_handler - async def on_chat_room_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: + async def on_chat_room_message(self, message: TextMessage, ctx: MessageContext) -> None: await self._app.post_runtime_message(message) diff --git a/python/samples/marketing-agents/auditor.py b/python/samples/marketing-agents/auditor.py index 132ee09d3..4ac299dfe 100644 --- a/python/samples/marketing-agents/auditor.py +++ b/python/samples/marketing-agents/auditor.py @@ -1,7 +1,7 @@ from agnext.components import TypeRoutedAgent, message_handler from agnext.components.models import ChatCompletionClient from agnext.components.models._types import SystemMessage -from agnext.core import CancellationToken +from agnext.core import MessageContext from messages import AuditorAlert, AuditText auditor_prompt = """You are an Auditor in a Marketing team @@ -24,7 +24,7 @@ class AuditAgent(TypeRoutedAgent): self._model_client = model_client @message_handler - async def handle_user_chat_input(self, message: AuditText, cancellation_token: CancellationToken) -> None: + async def handle_user_chat_input(self, message: AuditText, ctx: MessageContext) -> None: sys_prompt = auditor_prompt.format(input=message.text) completion = await self._model_client.create(messages=[SystemMessage(content=sys_prompt)]) assert isinstance(completion.content, str) diff --git a/python/samples/marketing-agents/graphic_designer.py b/python/samples/marketing-agents/graphic_designer.py index ae0873ec3..b997a1f88 100644 --- a/python/samples/marketing-agents/graphic_designer.py +++ b/python/samples/marketing-agents/graphic_designer.py @@ -6,7 +6,7 @@ from agnext.components import ( TypeRoutedAgent, message_handler, ) -from agnext.core import CancellationToken +from agnext.core import MessageContext from messages import ArticleCreated, GraphicDesignCreated @@ -21,7 +21,7 @@ class GraphicDesignerAgent(TypeRoutedAgent): self._model = model @message_handler - async def handle_user_chat_input(self, message: ArticleCreated, cancellation_token: CancellationToken) -> None: + async def handle_user_chat_input(self, message: ArticleCreated, ctx: MessageContext) -> None: logger = logging.getLogger("graphic_designer") try: logger.info(f"Asking model to generate an image for the article '{message.article}'.") diff --git a/python/samples/marketing-agents/test_usage.py b/python/samples/marketing-agents/test_usage.py index 3d22ba456..81b03e515 100644 --- a/python/samples/marketing-agents/test_usage.py +++ b/python/samples/marketing-agents/test_usage.py @@ -3,7 +3,7 @@ import os from agnext.application import SingleThreadedAgentRuntime from agnext.components import Image, TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import MessageContext from app import build_app from dotenv import load_dotenv from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated @@ -16,14 +16,14 @@ class Printer(TypeRoutedAgent): super().__init__("") @message_handler - async def handle_graphic_design(self, message: GraphicDesignCreated, cancellation_token: CancellationToken) -> None: + async def handle_graphic_design(self, message: GraphicDesignCreated, ctx: MessageContext) -> None: image = Image.from_uri(message.imageUri) # Save image to random name in current directory image.image.save(os.path.join(os.getcwd(), f"{message.UserId}.png")) print(f"Received GraphicDesignCreated: user {message.UserId}, saved to {message.UserId}.png") @message_handler - async def handle_auditor_alert(self, message: AuditorAlert, cancellation_token: CancellationToken) -> None: + async def handle_auditor_alert(self, message: AuditorAlert, ctx: MessageContext) -> None: print(f"Received AuditorAlert: {message.auditorAlertMessage} for user {message.UserId}") diff --git a/python/samples/patterns/coder_executor.py b/python/samples/patterns/coder_executor.py index 5c230c879..09278e63d 100644 --- a/python/samples/patterns/coder_executor.py +++ b/python/samples/patterns/coder_executor.py @@ -30,10 +30,10 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import CancellationToken sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -88,7 +88,7 @@ Reply "TERMINATE" in the end when everything is done.""" self._session_memory: Dict[str, List[LLMMessage]] = {} @message_handler - async def handle_task(self, message: TaskMessage, cancellation_token: CancellationToken) -> None: + async def handle_task(self, message: TaskMessage, ctx: MessageContext) -> None: # Create a new session. session_id = str(uuid.uuid4()) self._session_memory.setdefault(session_id, []).append(UserMessage(content=message.content, source="user")) @@ -102,13 +102,12 @@ Reply "TERMINATE" in the end when everything is done.""" # Publish the code execution task. await self.publish_message( - CodeExecutionTask(content=response.content, session_id=session_id), cancellation_token=cancellation_token + CodeExecutionTask(content=response.content, session_id=session_id), + cancellation_token=ctx.cancellation_token, ) @message_handler - async def handle_code_execution_result( - self, message: CodeExecutionTaskResult, cancellation_token: CancellationToken - ) -> None: + async def handle_code_execution_result(self, message: CodeExecutionTaskResult, ctx: MessageContext) -> None: # Store the code execution output. self._session_memory[message.session_id].append(UserMessage(content=message.output, source="user")) @@ -121,7 +120,9 @@ Reply "TERMINATE" in the end when everything is done.""" if "TERMINATE" in response.content: # If the task is completed, publish a message with the completion content. - await self.publish_message(TaskCompletion(content=response.content), cancellation_token=cancellation_token) + await self.publish_message( + TaskCompletion(content=response.content), cancellation_token=ctx.cancellation_token + ) print("--------------------") print("Task completed:") print(response.content) @@ -130,7 +131,7 @@ Reply "TERMINATE" in the end when everything is done.""" # Publish the code execution task. await self.publish_message( CodeExecutionTask(content=response.content, session_id=message.session_id), - cancellation_token=cancellation_token, + cancellation_token=ctx.cancellation_token, ) @@ -142,7 +143,7 @@ class Executor(TypeRoutedAgent): self._executor = executor @message_handler - async def handle_code_execution(self, message: CodeExecutionTask, cancellation_token: CancellationToken) -> None: + async def handle_code_execution(self, message: CodeExecutionTask, ctx: MessageContext) -> None: # Extract the code block from the message. code_blocks = self._extract_code_blocks(message.content) if not code_blocks: @@ -151,17 +152,17 @@ class Executor(TypeRoutedAgent): CodeExecutionTaskResult( output="Error: no Markdown code block found.", exit_code=1, session_id=message.session_id ), - cancellation_token=cancellation_token, + cancellation_token=ctx.cancellation_token, ) return # Execute code blocks. result = await self._executor.execute_code_blocks( - code_blocks=code_blocks, cancellation_token=cancellation_token + code_blocks=code_blocks, cancellation_token=ctx.cancellation_token ) # Publish the code execution result. await self.publish_message( CodeExecutionTaskResult(output=result.output, exit_code=result.exit_code, session_id=message.session_id), - cancellation_token=cancellation_token, + cancellation_token=ctx.cancellation_token, ) def _extract_code_blocks(self, markdown_text: str) -> List[CodeBlock]: diff --git a/python/samples/patterns/coder_reviewer.py b/python/samples/patterns/coder_reviewer.py index 8a7797931..ccd0e617a 100644 --- a/python/samples/patterns/coder_reviewer.py +++ b/python/samples/patterns/coder_reviewer.py @@ -29,10 +29,10 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import CancellationToken sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -89,7 +89,7 @@ Respond using the following JSON format: self._model_client = model_client @message_handler - async def handle_code_review_task(self, message: CodeReviewTask, cancellation_token: CancellationToken) -> None: + async def handle_code_review_task(self, message: CodeReviewTask, ctx: MessageContext) -> None: # Format the prompt for the code review. prompt = f"""The problem statement is: {message.code_writing_task} The code is: @@ -155,7 +155,7 @@ Code: async def handle_code_writing_task( self, message: CodeWritingTask, - cancellation_token: CancellationToken, + ctx: MessageContext, ) -> None: # Store the messages in a temporary memory for this request only. session_id = str(uuid.uuid4()) @@ -182,7 +182,7 @@ Code: await self.publish_message(code_review_task) @message_handler - async def handle_code_review_result(self, message: CodeReviewResult, cancellation_token: CancellationToken) -> None: + async def handle_code_review_result(self, message: CodeReviewResult, ctx: MessageContext) -> None: # Store the review result in the session memory. self._session_memory[message.session_id].append(message) # Obtain the request from previous messages. diff --git a/python/samples/patterns/group_chat.py b/python/samples/patterns/group_chat.py index a496f5285..d500554ca 100644 --- a/python/samples/patterns/group_chat.py +++ b/python/samples/patterns/group_chat.py @@ -26,10 +26,11 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import AgentId, CancellationToken +from agnext.core import AgentId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -62,7 +63,7 @@ class RoundRobinGroupChatManager(TypeRoutedAgent): self._round_count = 0 @message_handler - async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None: + async def handle_message(self, message: Message, ctx: MessageContext) -> None: # Select the next speaker in a round-robin fashion speaker = self._participants[self._round_count % len(self._participants)] self._round_count += 1 @@ -87,11 +88,11 @@ class GroupChatParticipant(TypeRoutedAgent): self._memory: List[Message] = [] @message_handler - async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None: + async def handle_message(self, message: Message, ctx: MessageContext) -> None: self._memory.append(message) @message_handler - async def handle_request_to_speak(self, message: RequestToSpeak, cancellation_token: CancellationToken) -> None: + async def handle_request_to_speak(self, message: RequestToSpeak, ctx: MessageContext) -> None: # Generate a response to the last message in the memory if not self._memory: return diff --git a/python/samples/patterns/mixture_of_agents.py b/python/samples/patterns/mixture_of_agents.py index d4b84d397..a99a1fb9b 100644 --- a/python/samples/patterns/mixture_of_agents.py +++ b/python/samples/patterns/mixture_of_agents.py @@ -17,7 +17,7 @@ from typing import Dict, List from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage -from agnext.core import CancellationToken +from agnext.core import MessageContext sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -60,7 +60,7 @@ class ReferenceAgent(TypeRoutedAgent): self._model_client = model_client @message_handler - async def handle_task(self, message: ReferenceAgentTask, cancellation_token: CancellationToken) -> None: + async def handle_task(self, message: ReferenceAgentTask, ctx: MessageContext) -> None: """Handle a task message. This method sends the task to the model and publishes the result.""" task_message = UserMessage(content=message.task, source=self.metadata["type"]) response = await self._model_client.create(self._system_messages + [task_message]) @@ -86,14 +86,14 @@ class AggregatorAgent(TypeRoutedAgent): self._session_results: Dict[str, List[ReferenceAgentTaskResult]] = {} @message_handler - async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> None: + async def handle_task(self, message: AggregatorTask, ctx: MessageContext) -> None: """Handle a task message. This method publishes the task to the reference agents.""" session_id = str(uuid.uuid4()) ref_task = ReferenceAgentTask(session_id=session_id, task=message.task) await self.publish_message(ref_task) @message_handler - async def handle_result(self, message: ReferenceAgentTaskResult, cancellation_token: CancellationToken) -> None: + async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None: """Handle a task result message. Once all results are received, this method aggregates the results and publishes the final result.""" self._session_results.setdefault(message.session_id, []).append(message) diff --git a/python/samples/patterns/multi_agent_debate.py b/python/samples/patterns/multi_agent_debate.py index daac66568..743f9c6fe 100644 --- a/python/samples/patterns/multi_agent_debate.py +++ b/python/samples/patterns/multi_agent_debate.py @@ -48,10 +48,10 @@ from agnext.components.models import ( SystemMessage, UserMessage, ) -from agnext.core import CancellationToken sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs logger = logging.getLogger(__name__) @@ -116,7 +116,7 @@ class MathSolver(TypeRoutedAgent): self._max_round = max_round @message_handler - async def handle_response(self, message: IntermediateSolverResponse, cancellation_token: CancellationToken) -> None: + async def handle_response(self, message: IntermediateSolverResponse, ctx: MessageContext) -> None: if message.solver_name not in self._neighbor_names: return # Add only neighbor's response to the buffer. @@ -143,7 +143,7 @@ class MathSolver(TypeRoutedAgent): self._buffer.pop((message.session_id, message.round)) @message_handler - async def handle_request(self, message: SolverRequest, cancellation_token: CancellationToken) -> None: + async def handle_request(self, message: SolverRequest, ctx: MessageContext) -> None: # Save the question. self._questions[message.session_id] = message.question # Add the question to the memory. @@ -186,7 +186,7 @@ class MathAggregator(TypeRoutedAgent): self._responses: Dict[str, List[FinalSolverResponse]] = {} @message_handler - async def handle_question(self, message: Question, cancellation_token: CancellationToken) -> None: + async def handle_question(self, message: Question, ctx: MessageContext) -> None: prompt = ( f"Can you solve the following math problem?\n{message.content}\n" "Explain your reasoning. Your final answer should be a single numerical number, " @@ -196,9 +196,7 @@ class MathAggregator(TypeRoutedAgent): await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content)) @message_handler - async def handle_final_solver_response( - self, message: FinalSolverResponse, cancellation_token: CancellationToken - ) -> None: + async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None: self._responses.setdefault(message.session_id, []).append(message) if len(self._responses[message.session_id]) == self._num_solvers: # Find the majority answer. diff --git a/python/samples/tool-use/coding_direct.py b/python/samples/tool-use/coding_direct.py index a153a2748..5f278d114 100644 --- a/python/samples/tool-use/coding_direct.py +++ b/python/samples/tool-use/coding_direct.py @@ -30,10 +30,11 @@ from agnext.components.models import ( ) from agnext.components.tool_agent import ToolAgent, ToolException from agnext.components.tools import PythonCodeExecutionTool, Tool, ToolSchema -from agnext.core import AgentId, CancellationToken +from agnext.core import AgentId sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -61,7 +62,7 @@ class ToolUseAgent(TypeRoutedAgent): self._tool_agent = tool_agent @message_handler - async def handle_user_message(self, message: Message, cancellation_token: CancellationToken) -> Message: + async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message: """Handle a user message, execute the model and tools, and returns the response.""" session: List[LLMMessage] = [] session.append(UserMessage(content=message.content, source="User")) @@ -72,7 +73,7 @@ class ToolUseAgent(TypeRoutedAgent): while isinstance(response.content, list) and all(isinstance(item, FunctionCall) for item in response.content): results: List[FunctionExecutionResult | BaseException] = await asyncio.gather( *[ - self.send_message(call, self._tool_agent, cancellation_token=cancellation_token) + self.send_message(call, self._tool_agent, cancellation_token=ctx.cancellation_token) for call in response.content ], return_exceptions=True, diff --git a/python/samples/tool-use/coding_pub_sub.py b/python/samples/tool-use/coding_pub_sub.py index 02ee67833..656f1abda 100644 --- a/python/samples/tool-use/coding_pub_sub.py +++ b/python/samples/tool-use/coding_pub_sub.py @@ -32,10 +32,10 @@ from agnext.components.models import ( UserMessage, ) from agnext.components.tools import PythonCodeExecutionTool, Tool -from agnext.core import CancellationToken sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from agnext.core import MessageContext from common.utils import get_chat_completion_client_from_envs @@ -69,7 +69,7 @@ class ToolExecutorAgent(TypeRoutedAgent): self._tools = tools @message_handler - async def handle_tool_call(self, message: ToolExecutionTask, cancellation_token: CancellationToken) -> None: + async def handle_tool_call(self, message: ToolExecutionTask, ctx: MessageContext) -> None: """Handle a tool execution task. This method executes the tool and publishes the result.""" # Find the tool tool = next((tool for tool in self._tools if tool.name == message.function_call.name), None) @@ -78,7 +78,7 @@ class ToolExecutorAgent(TypeRoutedAgent): else: try: arguments = json.loads(message.function_call.arguments) - result = await tool.run_json(args=arguments, cancellation_token=cancellation_token) + result = await tool.run_json(args=arguments, cancellation_token=ctx.cancellation_token) result_as_str = tool.return_value_as_string(result) except json.JSONDecodeError: result_as_str = f"Error: Invalid arguments: {message.function_call.arguments}" @@ -112,7 +112,7 @@ class ToolUseAgent(TypeRoutedAgent): self._tool_counter: Dict[str, int] = {} @message_handler - async def handle_user_message(self, message: UserRequest, cancellation_token: CancellationToken) -> None: + async def handle_user_message(self, message: UserRequest, ctx: MessageContext) -> None: """Handle a user message. This method calls the model. If the model response is a string, it publishes the response. If the model response is a list of function calls, it publishes the function calls to the tool executor agent.""" @@ -142,7 +142,7 @@ class ToolUseAgent(TypeRoutedAgent): await self.publish_message(task) @message_handler - async def handle_tool_result(self, message: ToolExecutionTaskResult, cancellation_token: CancellationToken) -> None: + async def handle_tool_result(self, message: ToolExecutionTaskResult, ctx: MessageContext) -> None: """Handle a tool execution result. This method aggregates the tool results and calls the model again to get another response. If the response is a string, it publishes the response. If the response is a list of function calls, it publishes diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index 839d0d219..a30840dec 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -12,6 +12,8 @@ from dataclasses import dataclass from enum import Enum from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast +from agnext.core import MessageContext + from ..core import ( MESSAGE_TYPE_REGISTRY, Agent, @@ -274,9 +276,15 @@ class SingleThreadedAgentRuntime(AgentRuntime): # ) # ) recipient_agent = await self._get_agent(recipient) + message_context = MessageContext( + sender=message_envelope.sender, + topic_id=None, + is_rpc=True, + cancellation_token=message_envelope.cancellation_token, + ) response = await recipient_agent.on_message( message_envelope.message, - cancellation_token=message_envelope.cancellation_token, + ctx=message_context, ) except BaseException as e: message_envelope.future.set_exception(e) @@ -317,11 +325,17 @@ class SingleThreadedAgentRuntime(AgentRuntime): # delivery_stage=DeliveryStage.DELIVER, # ) # ) - + message_context = MessageContext( + sender=message_envelope.sender, + # TODO: topic_id + topic_id=None, + is_rpc=False, + cancellation_token=message_envelope.cancellation_token, + ) agent = await self._get_agent(agent_id) future = agent.on_message( message_envelope.message, - cancellation_token=message_envelope.cancellation_token, + ctx=message_context, ) responses.append(future) diff --git a/python/src/agnext/components/_closure_agent.py b/python/src/agnext/components/_closure_agent.py index 1636bd2fa..21ffd9511 100644 --- a/python/src/agnext/components/_closure_agent.py +++ b/python/src/agnext/components/_closure_agent.py @@ -1,12 +1,13 @@ import inspect from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_type_hints +from agnext.core import MessageContext + from ..core._agent import Agent from ..core._agent_id import AgentId from ..core._agent_instantiation import AgentInstantiationContext from ..core._agent_metadata import AgentMetadata from ..core._agent_runtime import AgentRuntime -from ..core._cancellation_token import CancellationToken from ..core._serialization import MESSAGE_TYPE_REGISTRY from ..core.exceptions import CantHandleException from ._type_helpers import get_types @@ -15,7 +16,7 @@ T = TypeVar("T") def get_subscriptions_from_closure( - closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]], + closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]], ) -> Sequence[type]: args = inspect.getfullargspec(closure)[0] if len(args) != 4: @@ -44,7 +45,7 @@ def get_subscriptions_from_closure( class ClosureAgent(Agent): def __init__( - self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]] + self, description: str, closure: Callable[[AgentRuntime, AgentId, T, MessageContext], Awaitable[Any]] ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -82,12 +83,12 @@ class ClosureAgent(Agent): def runtime(self) -> AgentRuntime: return self._runtime - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + async def on_message(self, message: Any, ctx: MessageContext) -> Any: if MESSAGE_TYPE_REGISTRY.type_name(message) not in self._subscriptions: raise CantHandleException( f"Message type {type(message)} not in target types {self._subscriptions} of {self.id}" ) - return await self._closure(self._runtime, self._id, message, cancellation_token) + return await self._closure(self._runtime, self._id, message, ctx) def save_state(self) -> Mapping[str, Any]: raise ValueError("save_state not implemented for ClosureAgent") diff --git a/python/src/agnext/components/_type_routed_agent.py b/python/src/agnext/components/_type_routed_agent.py index 2c3d9c655..ba5ee5384 100644 --- a/python/src/agnext/components/_type_routed_agent.py +++ b/python/src/agnext/components/_type_routed_agent.py @@ -17,7 +17,7 @@ from typing import ( runtime_checkable, ) -from ..core import MESSAGE_TYPE_REGISTRY, BaseAgent, CancellationToken +from ..core import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext from ..core.exceptions import CantHandleException from ._type_helpers import AnyType, get_types @@ -36,7 +36,7 @@ class MessageHandler(Protocol[ReceivesT, ProducesT]): produces_types: Sequence[type] is_message_handler: Literal[True] - async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ... + async def __call__(self, message: ReceivesT, ctx: MessageContext) -> ProducesT: ... # NOTE: this works on concrete types and not inheritance @@ -45,7 +45,7 @@ class MessageHandler(Protocol[ReceivesT, ProducesT]): @overload def message_handler( - func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]], + func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], ) -> MessageHandler[ReceivesT, ProducesT]: ... @@ -55,24 +55,24 @@ def message_handler( *, strict: bool = ..., ) -> Callable[ - [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]], + [Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], MessageHandler[ReceivesT, ProducesT], ]: ... def message_handler( - func: None | Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]] = None, + func: None | Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]] = None, *, strict: bool = True, ) -> ( Callable[ - [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]], + [Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]]], MessageHandler[ReceivesT, ProducesT], ] | MessageHandler[ReceivesT, ProducesT] ): def decorator( - func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]], + func: Callable[[Any, ReceivesT, MessageContext], Coroutine[Any, Any, ProducesT]], ) -> MessageHandler[ReceivesT, ProducesT]: type_hints = get_type_hints(func) if "message" not in type_hints: @@ -95,14 +95,14 @@ def message_handler( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: + async def wrapper(self: Any, message: ReceivesT, ctx: MessageContext) -> ProducesT: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") else: logger.warning(f"Message type {type(message)} not in target types {target_types}") - return_value = await func(self, message, cancellation_token) + return_value = await func(self, message, ctx) if AnyType not in return_types and type(return_value) not in return_types: if strict: @@ -132,7 +132,7 @@ class TypeRoutedAgent(BaseAgent): # Self is already bound to the handlers self._handlers: Dict[ Type[Any], - Callable[[Any, CancellationToken], Coroutine[Any, Any, Any | None]], + Callable[[Any, MessageContext], Coroutine[Any, Any, Any | None]], ] = {} for attr in dir(self): @@ -149,13 +149,13 @@ class TypeRoutedAgent(BaseAgent): subscriptions_str = [MESSAGE_TYPE_REGISTRY.type_name(message_type) for message_type in subscriptions] super().__init__(description, subscriptions_str) - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: + async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: key_type: Type[Any] = type(message) # type: ignore handler = self._handlers.get(key_type) # type: ignore if handler is not None: - return await handler(message, cancellation_token) + return await handler(message, ctx) else: - return await self.on_unhandled_message(message, cancellation_token) + return await self.on_unhandled_message(message, ctx) - async def on_unhandled_message(self, message: Any, cancellation_token: CancellationToken) -> NoReturn: + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: raise CantHandleException(f"Unhandled message: {message}") diff --git a/python/src/agnext/components/tool_agent/_tool_agent.py b/python/src/agnext/components/tool_agent/_tool_agent.py index 7e6056f55..2f18c1b14 100644 --- a/python/src/agnext/components/tool_agent/_tool_agent.py +++ b/python/src/agnext/components/tool_agent/_tool_agent.py @@ -2,7 +2,7 @@ import json from dataclasses import dataclass from typing import List -from ...core import CancellationToken +from ...core import MessageContext from .. import FunctionCall, TypeRoutedAgent, message_handler from ..models import FunctionExecutionResult from ..tools import Tool @@ -60,9 +60,7 @@ class ToolAgent(TypeRoutedAgent): return self._tools @message_handler - async def handle_function_call( - self, message: FunctionCall, cancellation_token: CancellationToken - ) -> FunctionExecutionResult: + async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) -> FunctionExecutionResult: """Handles a `FunctionCall` message by executing the requested tool with the provided arguments. Args: @@ -83,7 +81,7 @@ class ToolAgent(TypeRoutedAgent): else: try: arguments = json.loads(message.arguments) - result = await tool.run_json(args=arguments, cancellation_token=cancellation_token) + result = await tool.run_json(args=arguments, cancellation_token=ctx.cancellation_token) result_as_str = tool.return_value_as_string(result) except json.JSONDecodeError as e: raise InvalidToolArgumentsException( diff --git a/python/src/agnext/core/__init__.py b/python/src/agnext/core/__init__.py index 45e155884..1ce9345b9 100644 --- a/python/src/agnext/core/__init__.py +++ b/python/src/agnext/core/__init__.py @@ -11,6 +11,7 @@ from ._agent_proxy import AgentProxy from ._agent_runtime import AgentRuntime from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken +from ._message_context import MessageContext from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer from ._subscription import Subscription from ._topic import TopicId @@ -30,4 +31,5 @@ __all__ = [ "TypeDeserializer", "TopicId", "Subscription", + "MessageContext", ] diff --git a/python/src/agnext/core/_agent.py b/python/src/agnext/core/_agent.py index 4bb07ed3c..376efa254 100644 --- a/python/src/agnext/core/_agent.py +++ b/python/src/agnext/core/_agent.py @@ -2,7 +2,7 @@ from typing import Any, Mapping, Protocol, runtime_checkable from ._agent_id import AgentId from ._agent_metadata import AgentMetadata -from ._cancellation_token import CancellationToken +from ._message_context import MessageContext @runtime_checkable @@ -17,12 +17,12 @@ class Agent(Protocol): """ID of the agent.""" ... - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + async def on_message(self, message: Any, ctx: MessageContext) -> Any: """Message handler for the agent. This should only be called by the runtime, not by other agents. Args: message (Any): Received message. Type is one of the types in `subscriptions`. - cancellation_token (CancellationToken): Cancellation token for the message. + ctx (MessageContext): Context of the message. Returns: Any: Response to the message. Can be None. diff --git a/python/src/agnext/core/_base_agent.py b/python/src/agnext/core/_base_agent.py index 6963ee1b0..348ed9559 100644 --- a/python/src/agnext/core/_base_agent.py +++ b/python/src/agnext/core/_base_agent.py @@ -8,6 +8,7 @@ from ._agent_instantiation import AgentInstantiationContext from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._cancellation_token import CancellationToken +from ._message_context import MessageContext class BaseAgent(ABC, Agent): @@ -50,7 +51,7 @@ class BaseAgent(ABC, Agent): return self._runtime @abstractmethod - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ... + async def on_message(self, message: Any, ctx: MessageContext) -> Any: ... async def send_message( self, diff --git a/python/src/agnext/core/_message_context.py b/python/src/agnext/core/_message_context.py new file mode 100644 index 000000000..0a2c2973b --- /dev/null +++ b/python/src/agnext/core/_message_context.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from ._agent_id import AgentId +from ._cancellation_token import CancellationToken +from ._topic import TopicId + + +@dataclass +class MessageContext: + sender: AgentId | None + topic_id: TopicId | None + is_rpc: bool + cancellation_token: CancellationToken diff --git a/python/src/agnext/worker/worker_runtime.py b/python/src/agnext/worker/worker_runtime.py index 71bfc71de..2b4a5ce45 100644 --- a/python/src/agnext/worker/worker_runtime.py +++ b/python/src/agnext/worker/worker_runtime.py @@ -31,7 +31,7 @@ import grpc from grpc.aio import StreamStreamCall from typing_extensions import Self -from agnext.core import MESSAGE_TYPE_REGISTRY +from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentProxy, AgentRuntime, CancellationToken from .protos import AgentId as AgentIdProto @@ -248,8 +248,16 @@ class WorkerAgentRuntime(AgentRuntime): ]: logger.info("Sending message to %s", agent_id) agent = await self._get_agent(agent_id) + message_context = MessageContext( + # TODO: should sender be in the proto even for published events? + sender=None, + # TODO: topic_id + topic_id=None, + is_rpc=False, + cancellation_token=CancellationToken(), + ) try: - await agent.on_message(message, CancellationToken()) + await agent.on_message(message, ctx=message_context) logger.info("%s handled event %s", agent_id, message) except Exception as e: event_logger.error("Error handling message", exc_info=e) diff --git a/python/teams/team-one/src/team_one/agents/base_agent.py b/python/teams/team-one/src/team_one/agents/base_agent.py index 347c4581a..ae5b44ba4 100644 --- a/python/teams/team-one/src/team_one/agents/base_agent.py +++ b/python/teams/team-one/src/team_one/agents/base_agent.py @@ -4,7 +4,7 @@ from typing import Any from agnext.application.logging import EVENT_LOGGER_NAME from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken +from agnext.core import MessageContext from team_one.messages import ( AgentEvent, @@ -32,25 +32,25 @@ class TeamOneBaseAgent(TypeRoutedAgent): if not self._handle_messages_concurrently: # TODO: make it possible to stop - self._message_queue = asyncio.Queue[tuple[TeamOneMessages, CancellationToken, asyncio.Future[Any]]]() + self._message_queue = asyncio.Queue[tuple[TeamOneMessages, MessageContext, asyncio.Future[Any]]]() self._processing_task = asyncio.create_task(self._process()) async def _process(self) -> None: while True: - message, cancellation_token, future = await self._message_queue.get() - if cancellation_token.is_cancelled(): + message, ctx, future = await self._message_queue.get() + if ctx.cancellation_token.is_cancelled(): # TODO: Do we need to resolve the future here? continue try: if isinstance(message, RequestReplyMessage): - await self._handle_request_reply(message, cancellation_token) + await self._handle_request_reply(message, ctx) elif isinstance(message, BroadcastMessage): - await self._handle_broadcast(message, cancellation_token) + await self._handle_broadcast(message, ctx) elif isinstance(message, ResetMessage): - await self._handle_reset(message, cancellation_token) + await self._handle_reset(message, ctx) elif isinstance(message, DeactivateMessage): - await self._handle_deactivate(message, cancellation_token) + await self._handle_deactivate(message, ctx) else: raise ValueError("Unknown message type.") future.set_result(None) @@ -61,35 +61,35 @@ class TeamOneBaseAgent(TypeRoutedAgent): async def handle_incoming_message( self, message: BroadcastMessage | ResetMessage | DeactivateMessage | RequestReplyMessage, - cancellation_token: CancellationToken, + ctx: MessageContext, ) -> None: if not self._enabled: return if self._handle_messages_concurrently: if isinstance(message, RequestReplyMessage): - await self._handle_request_reply(message, cancellation_token) + await self._handle_request_reply(message, ctx) elif isinstance(message, BroadcastMessage): - await self._handle_broadcast(message, cancellation_token) + await self._handle_broadcast(message, ctx) elif isinstance(message, ResetMessage): - await self._handle_reset(message, cancellation_token) + await self._handle_reset(message, ctx) elif isinstance(message, DeactivateMessage): - await self._handle_deactivate(message, cancellation_token) + await self._handle_deactivate(message, ctx) else: future = asyncio.Future[Any]() - await self._message_queue.put((message, cancellation_token, future)) + await self._message_queue.put((message, ctx, future)) await future - async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None: raise NotImplementedError() - async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None: + async def _handle_reset(self, message: ResetMessage, ctx: MessageContext) -> None: raise NotImplementedError() - async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None: + async def _handle_request_reply(self, message: RequestReplyMessage, ctx: MessageContext) -> None: raise NotImplementedError() - async def _handle_deactivate(self, message: DeactivateMessage, cancellation_token: CancellationToken) -> None: + async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None: """Handle a deactivate message.""" self._enabled = False logger.info( diff --git a/python/teams/team-one/src/team_one/agents/base_orchestrator.py b/python/teams/team-one/src/team_one/agents/base_orchestrator.py index 900ccb268..c039bfa79 100644 --- a/python/teams/team-one/src/team_one/agents/base_orchestrator.py +++ b/python/teams/team-one/src/team_one/agents/base_orchestrator.py @@ -4,7 +4,7 @@ from typing import List, Optional from agnext.application.logging import EVENT_LOGGER_NAME from agnext.components.models import AssistantMessage, LLMMessage, UserMessage -from agnext.core import AgentProxy, CancellationToken +from agnext.core import AgentProxy, CancellationToken, MessageContext from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage, ResetMessage from ..utils import message_content_to_str @@ -29,7 +29,7 @@ class BaseOrchestrator(TeamOneBaseAgent): self._num_rounds = 0 self._start_time: float = -1.0 - async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None: """Handle an incoming message.""" # First broadcast sets the timer @@ -100,9 +100,9 @@ class BaseOrchestrator(TeamOneBaseAgent): def get_max_rounds(self) -> int: return self._max_rounds - async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None: + async def _handle_reset(self, message: ResetMessage, ctx: MessageContext) -> None: """Handle a reset message.""" - await self._reset(cancellation_token) + await self._reset(ctx.cancellation_token) async def _reset(self, cancellation_token: CancellationToken) -> None: pass diff --git a/python/teams/team-one/src/team_one/agents/base_worker.py b/python/teams/team-one/src/team_one/agents/base_worker.py index af4d13327..a3e089c09 100644 --- a/python/teams/team-one/src/team_one/agents/base_worker.py +++ b/python/teams/team-one/src/team_one/agents/base_worker.py @@ -5,7 +5,7 @@ from agnext.components.models import ( LLMMessage, UserMessage, ) -from agnext.core import CancellationToken +from agnext.core import CancellationToken, MessageContext from team_one.messages import ( BroadcastMessage, @@ -29,17 +29,17 @@ class BaseWorker(TeamOneBaseAgent): super().__init__(description, handle_messages_concurrently=handle_messages_concurrently) self._chat_history: List[LLMMessage] = [] - async def _handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None: assert isinstance(message.content, UserMessage) self._chat_history.append(message.content) - async def _handle_reset(self, message: ResetMessage, cancellation_token: CancellationToken) -> None: + async def _handle_reset(self, message: ResetMessage, ctx: MessageContext) -> None: """Handle a reset message.""" - await self._reset(cancellation_token) + await self._reset(ctx.cancellation_token) - async def _handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None: + async def _handle_request_reply(self, message: RequestReplyMessage, ctx: MessageContext) -> None: """Respond to a reply request.""" - request_halt, response = await self._generate_reply(cancellation_token) + request_halt, response = await self._generate_reply(ctx.cancellation_token) assistant_message = AssistantMessage(content=message_content_to_str(response), source=self.metadata["type"]) self._chat_history.append(assistant_message) diff --git a/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py b/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py index 568d806a5..293e50f7f 100644 --- a/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py +++ b/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py @@ -30,12 +30,7 @@ from playwright._impl._errors import TimeoutError from playwright.async_api import BrowserContext, Download, Page, Playwright, async_playwright # TODO: Fix mdconvert -from ...markdown_browser import ( # type: ignore - DocumentConverterResult, # type: ignore - FileConversionException, # type: ignore - MarkdownConverter, # type: ignore - UnsupportedFormatException, # type: ignore -) +from ...markdown_browser import MarkdownConverter # type: ignore from ...messages import UserContent, WebSurferEvent from ...utils import SentinelMeta, message_content_to_str from ..base_worker import BaseWorker diff --git a/python/teams/team-one/src/team_one/agents/reflex_agents.py b/python/teams/team-one/src/team_one/agents/reflex_agents.py index 2c51f6c9a..cd0df896e 100644 --- a/python/teams/team-one/src/team_one/agents/reflex_agents.py +++ b/python/teams/team-one/src/team_one/agents/reflex_agents.py @@ -1,6 +1,6 @@ from agnext.components import TypeRoutedAgent, message_handler from agnext.components.models import UserMessage -from agnext.core import CancellationToken +from agnext.core import MessageContext from ..messages import BroadcastMessage, RequestReplyMessage @@ -10,14 +10,12 @@ class ReflexAgent(TypeRoutedAgent): super().__init__(description) @message_handler - async def handle_incoming_message(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None: + async def handle_incoming_message(self, message: BroadcastMessage, ctx: MessageContext) -> None: """Handle an incoming message.""" pass @message_handler - async def handle_request_reply_message( - self, message: RequestReplyMessage, cancellation_token: CancellationToken - ) -> None: + async def handle_request_reply_message(self, message: RequestReplyMessage, ctx: MessageContext) -> None: name = self.metadata["type"] response_message = UserMessage( @@ -25,4 +23,4 @@ class ReflexAgent(TypeRoutedAgent): source=name, ) - await self.publish_message(BroadcastMessage(response_message), cancellation_token=cancellation_token) + await self.publish_message(BroadcastMessage(response_message)) diff --git a/python/tests/test_cancellation.py b/python/tests/test_cancellation.py index a077a2c9b..a5ac020bf 100644 --- a/python/tests/test_cancellation.py +++ b/python/tests/test_cancellation.py @@ -5,6 +5,7 @@ import pytest from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler from agnext.core import AgentId, CancellationToken +from agnext.core import MessageContext @dataclass @@ -22,10 +23,10 @@ class LongRunningAgent(TypeRoutedAgent): self.cancelled = False @message_handler - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) - cancellation_token.link_future(sleep) + ctx.cancellation_token.link_future(sleep) try: await sleep return MessageType() @@ -41,9 +42,9 @@ class NestingLongRunningAgent(TypeRoutedAgent): self._nested_agent = nested_agent @message_handler - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.called = True - response = self.send_message(message, self._nested_agent, cancellation_token=cancellation_token) + response = self.send_message(message, self._nested_agent, cancellation_token=ctx.cancellation_token) try: val = await response assert isinstance(val, MessageType) diff --git a/python/tests/test_closure_agent.py b/python/tests/test_closure_agent.py index cb5512041..ce2f93d13 100644 --- a/python/tests/test_closure_agent.py +++ b/python/tests/test_closure_agent.py @@ -5,13 +5,15 @@ from dataclasses import dataclass import pytest from agnext.application import SingleThreadedAgentRuntime -from agnext.core import AgentRuntime, AgentId, CancellationToken +from agnext.core import AgentRuntime, AgentId from agnext.components import ClosureAgent import asyncio +from agnext.core import MessageContext + @dataclass class Message: content: str @@ -24,7 +26,7 @@ async def test_register_receives_publish() -> None: queue = asyncio.Queue[tuple[str, str]]() - async def log_message(_runtime: AgentRuntime, id: AgentId, message: Message, cancellation_token: CancellationToken) -> None: + async def log_message(_runtime: AgentRuntime, id: AgentId, message: Message, ctx: MessageContext) -> None: key = id.key await queue.put((key, message.content)) diff --git a/python/tests/test_state.py b/python/tests/test_state.py index b88f79b20..8bc562d1c 100644 --- a/python/tests/test_state.py +++ b/python/tests/test_state.py @@ -2,7 +2,7 @@ from typing import Any, Mapping, Sequence import pytest from agnext.application import SingleThreadedAgentRuntime -from agnext.core import BaseAgent, CancellationToken +from agnext.core import BaseAgent, MessageContext class StatefulAgent(BaseAgent): @@ -14,7 +14,7 @@ class StatefulAgent(BaseAgent): def subscriptions(self) -> Sequence[type]: return [] - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> None: + async def on_message(self, message: Any, ctx: MessageContext) -> None: raise NotImplementedError def save_state(self) -> Mapping[str, Any]: diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 4a883ee2c..7298f197a 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Union from agnext.components._type_routed_agent import message_handler from agnext.components._type_helpers import AnyType, get_types -from agnext.core import CancellationToken +from agnext.core import MessageContext def test_get_types() -> None: @@ -21,11 +21,11 @@ def test_handler() -> None: class HandlerClass: @message_handler() - async def handler(self, message: int, cancellation_token: CancellationToken) -> Any: + async def handler(self, message: int, ctx: MessageContext) -> Any: return None @message_handler() - async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None: + async def handler2(self, message: str | bool, ctx: MessageContext) -> None: return None assert HandlerClass.handler.target_types == [int] @@ -36,5 +36,5 @@ def test_handler() -> None: class HandlerClass: @message_handler() - async def handler(self, message: int, cancellation_token: CancellationToken) -> Any: + async def handler(self, message: int, ctx: MessageContext) -> Any: return None diff --git a/python/tests/test_utils/__init__.py b/python/tests/test_utils/__init__.py index 69c81039f..9ce86897f 100644 --- a/python/tests/test_utils/__init__.py +++ b/python/tests/test_utils/__init__.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from typing import Any from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import CancellationToken, BaseAgent +from agnext.core import BaseAgent +from agnext.core import MessageContext @dataclass @@ -20,7 +21,7 @@ class LoopbackAgent(TypeRoutedAgent): @message_handler - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.num_calls += 1 return message @@ -31,9 +32,9 @@ class CascadingAgent(TypeRoutedAgent): super().__init__("A cascading agent.") self.num_calls = 0 self.max_rounds = max_rounds - + @message_handler - async def on_new_message(self, message: CascadingMessageType, cancellation_token: CancellationToken) -> None: + async def on_new_message(self, message: CascadingMessageType, ctx: MessageContext) -> None: self.num_calls += 1 if message.round == self.max_rounds: return @@ -43,5 +44,5 @@ class NoopAgent(BaseAgent): def __init__(self) -> None: super().__init__("A no op agent", []) - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + async def on_message(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError