mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Move grpc runtimes to ext, flatten application (#4553)
* Move grpc runtimes to ext, flatten application * rename to grpc * fmt
This commit is contained in:
@@ -9,7 +9,7 @@ from openai import AzureOpenAI
|
||||
from typing import List
|
||||
|
||||
from autogen_core import AgentId, AgentProxy, TopicId
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_core.components.models import (
|
||||
ChatCompletionClient,
|
||||
|
||||
@@ -9,7 +9,7 @@ from openai import AzureOpenAI
|
||||
from typing import List
|
||||
|
||||
from autogen_core import AgentId, AgentProxy, TopicId
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_core.components.models import (
|
||||
ChatCompletionClient,
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
|
||||
from autogen_core import AgentId, AgentProxy, TopicId
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_core import DefaultSubscription, DefaultTopicId
|
||||
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
|
||||
|
||||
@@ -8,7 +8,7 @@ import nltk
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from autogen_core import AgentId, AgentProxy, TopicId
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_core import DefaultSubscription, DefaultTopicId
|
||||
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
|
||||
|
||||
@@ -12,10 +12,10 @@ from autogen_core import (
|
||||
CancellationToken,
|
||||
ClosureAgent,
|
||||
MessageContext,
|
||||
SingleThreadedAgentRuntime,
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core._closure_agent import ClosureContext
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
|
||||
@@ -5,8 +5,14 @@ from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.teams._group_chat._sequential_routed_agent import SequentialRoutedAgent
|
||||
from autogen_core import AgentId, DefaultTopicId, MessageContext, default_subscription, message_handler
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
DefaultTopicId,
|
||||
MessageContext,
|
||||
SingleThreadedAgentRuntime,
|
||||
default_subscription,
|
||||
message_handler,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -25,8 +25,14 @@
|
||||
"import asyncio\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from autogen_core import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId, MessageContext\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime"
|
||||
"from autogen_core import (\n",
|
||||
" ClosureAgent,\n",
|
||||
" ClosureContext,\n",
|
||||
" DefaultSubscription,\n",
|
||||
" DefaultTopicId,\n",
|
||||
" MessageContext,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -35,15 +35,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any, Callable, List, Literal\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
|
||||
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
|
||||
"from langchain_core.messages import HumanMessage, SystemMessage\n",
|
||||
"from langchain_core.tools import tool # pyright: ignore\n",
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -41,8 +41,7 @@
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import List, Optional\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
|
||||
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
|
||||
"from llama_index.core import Settings\n",
|
||||
"from llama_index.core.agent import ReActAgent\n",
|
||||
|
||||
@@ -39,8 +39,15 @@
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import (\n",
|
||||
" AgentId,\n",
|
||||
" DefaultTopicId,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" default_subscription,\n",
|
||||
" message_handler,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" AssistantMessage,\n",
|
||||
|
||||
@@ -386,7 +386,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import SingleThreadedAgentRuntime\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"await OpenAIAssistantAgent.register(\n",
|
||||
|
||||
@@ -22,8 +22,15 @@
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import (\n",
|
||||
" AgentId,\n",
|
||||
" DefaultTopicId,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" default_subscription,\n",
|
||||
" message_handler,\n",
|
||||
")\n",
|
||||
"from autogen_core.base.intervention import DefaultInterventionHandler"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -1,283 +1,290 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# User Approval for Tool Execution using Intervention Handler\n",
|
||||
"\n",
|
||||
"This cookbook shows how to intercept the tool execution using\n",
|
||||
"an intervention hanlder, and prompt the user for permission to execute the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any, List\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, AgentType, FunctionCall, MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" ChatCompletionClient,\n",
|
||||
" LLMMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" UserMessage,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
|
||||
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
|
||||
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's define a simple message type that carries a string content."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@dataclass\n",
|
||||
"class Message:\n",
|
||||
" content: str"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create a simple tool use agent that is capable of using tools through a\n",
|
||||
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ToolUseAgent(RoutedAgent):\n",
|
||||
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
|
||||
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" description: str,\n",
|
||||
" system_messages: List[SystemMessage],\n",
|
||||
" model_client: ChatCompletionClient,\n",
|
||||
" tool_schema: List[ToolSchema],\n",
|
||||
" tool_agent_type: AgentType,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__(description)\n",
|
||||
" self._model_client = model_client\n",
|
||||
" self._system_messages = system_messages\n",
|
||||
" self._tool_schema = tool_schema\n",
|
||||
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
|
||||
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
|
||||
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
|
||||
" # Use the tool agent to execute the tools, and get the output messages.\n",
|
||||
" output_messages = await tool_agent_caller_loop(\n",
|
||||
" self,\n",
|
||||
" tool_agent_id=self._tool_agent_id,\n",
|
||||
" model_client=self._model_client,\n",
|
||||
" input_messages=session,\n",
|
||||
" tool_schema=self._tool_schema,\n",
|
||||
" cancellation_token=ctx.cancellation_token,\n",
|
||||
" )\n",
|
||||
" # Extract the final response from the output messages.\n",
|
||||
" final_response = output_messages[-1].content\n",
|
||||
" assert isinstance(final_response, str)\n",
|
||||
" return Message(content=final_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
|
||||
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
|
||||
"to prompt the user for permission to execute the tool.\n",
|
||||
"\n",
|
||||
"Let's create an intervention handler that intercepts the messages and prompts\n",
|
||||
"user for before allowing the tool execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
|
||||
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
|
||||
" if isinstance(message, FunctionCall):\n",
|
||||
" # Request user prompt for tool execution.\n",
|
||||
" user_input = input(\n",
|
||||
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
|
||||
" )\n",
|
||||
" if user_input.strip().lower() != \"y\":\n",
|
||||
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
|
||||
" return message"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we can create a runtime with the intervention handler registered."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create the runtime with the intervention handler.\n",
|
||||
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this example, we will use a tool for Python code execution.\n",
|
||||
"First, we create a Docker-based command-line code executor\n",
|
||||
"using {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`,\n",
|
||||
"and then use it to instantiate a built-in Python code execution tool\n",
|
||||
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
|
||||
"that runs code in a Docker container."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create the docker executor for the Python code execution tool.\n",
|
||||
"docker_executor = DockerCommandLineCodeExecutor()\n",
|
||||
"\n",
|
||||
"# Create the Python code execution tool.\n",
|
||||
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Register the agents with tools and tool schema."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AgentType(type='tool_enabled_agent')"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Register agents.\n",
|
||||
"tool_agent_type = await ToolAgent.register(\n",
|
||||
" runtime,\n",
|
||||
" \"tool_executor_agent\",\n",
|
||||
" lambda: ToolAgent(\n",
|
||||
" description=\"Tool Executor Agent\",\n",
|
||||
" tools=[python_tool],\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"await ToolUseAgent.register(\n",
|
||||
" runtime,\n",
|
||||
" \"tool_enabled_agent\",\n",
|
||||
" lambda: ToolUseAgent(\n",
|
||||
" description=\"Tool Use Agent\",\n",
|
||||
" system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
|
||||
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
|
||||
" tool_schema=[python_tool.schema],\n",
|
||||
" tool_agent_type=tool_agent_type,\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
|
||||
"The intervention handler will prompt you for permission to execute the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The output of the code is: **Hello, World!**\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Start the runtime and the docker executor.\n",
|
||||
"await docker_executor.start()\n",
|
||||
"runtime.start()\n",
|
||||
"\n",
|
||||
"# Send a task to the tool user.\n",
|
||||
"response = await runtime.send_message(\n",
|
||||
" Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
|
||||
")\n",
|
||||
"print(response.content)\n",
|
||||
"\n",
|
||||
"# Stop the runtime and the docker executor.\n",
|
||||
"await runtime.stop()\n",
|
||||
"await docker_executor.stop()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# User Approval for Tool Execution using Intervention Handler\n",
|
||||
"\n",
|
||||
"This cookbook shows how to intercept the tool execution using\n",
|
||||
"an intervention hanlder, and prompt the user for permission to execute the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any, List\n",
|
||||
"\n",
|
||||
"from autogen_core import (\n",
|
||||
" AgentId,\n",
|
||||
" AgentType,\n",
|
||||
" FunctionCall,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" message_handler,\n",
|
||||
")\n",
|
||||
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" ChatCompletionClient,\n",
|
||||
" LLMMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" UserMessage,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
|
||||
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
|
||||
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's define a simple message type that carries a string content."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@dataclass\n",
|
||||
"class Message:\n",
|
||||
" content: str"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create a simple tool use agent that is capable of using tools through a\n",
|
||||
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ToolUseAgent(RoutedAgent):\n",
|
||||
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
|
||||
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" description: str,\n",
|
||||
" system_messages: List[SystemMessage],\n",
|
||||
" model_client: ChatCompletionClient,\n",
|
||||
" tool_schema: List[ToolSchema],\n",
|
||||
" tool_agent_type: AgentType,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__(description)\n",
|
||||
" self._model_client = model_client\n",
|
||||
" self._system_messages = system_messages\n",
|
||||
" self._tool_schema = tool_schema\n",
|
||||
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
|
||||
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
|
||||
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
|
||||
" # Use the tool agent to execute the tools, and get the output messages.\n",
|
||||
" output_messages = await tool_agent_caller_loop(\n",
|
||||
" self,\n",
|
||||
" tool_agent_id=self._tool_agent_id,\n",
|
||||
" model_client=self._model_client,\n",
|
||||
" input_messages=session,\n",
|
||||
" tool_schema=self._tool_schema,\n",
|
||||
" cancellation_token=ctx.cancellation_token,\n",
|
||||
" )\n",
|
||||
" # Extract the final response from the output messages.\n",
|
||||
" final_response = output_messages[-1].content\n",
|
||||
" assert isinstance(final_response, str)\n",
|
||||
" return Message(content=final_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
|
||||
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
|
||||
"to prompt the user for permission to execute the tool.\n",
|
||||
"\n",
|
||||
"Let's create an intervention handler that intercepts the messages and prompts\n",
|
||||
"user for before allowing the tool execution."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
|
||||
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
|
||||
" if isinstance(message, FunctionCall):\n",
|
||||
" # Request user prompt for tool execution.\n",
|
||||
" user_input = input(\n",
|
||||
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
|
||||
" )\n",
|
||||
" if user_input.strip().lower() != \"y\":\n",
|
||||
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
|
||||
" return message"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we can create a runtime with the intervention handler registered."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create the runtime with the intervention handler.\n",
|
||||
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this example, we will use a tool for Python code execution.\n",
|
||||
"First, we create a Docker-based command-line code executor\n",
|
||||
"using {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`,\n",
|
||||
"and then use it to instantiate a built-in Python code execution tool\n",
|
||||
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
|
||||
"that runs code in a Docker container."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create the docker executor for the Python code execution tool.\n",
|
||||
"docker_executor = DockerCommandLineCodeExecutor()\n",
|
||||
"\n",
|
||||
"# Create the Python code execution tool.\n",
|
||||
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Register the agents with tools and tool schema."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AgentType(type='tool_enabled_agent')"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Register agents.\n",
|
||||
"tool_agent_type = await ToolAgent.register(\n",
|
||||
" runtime,\n",
|
||||
" \"tool_executor_agent\",\n",
|
||||
" lambda: ToolAgent(\n",
|
||||
" description=\"Tool Executor Agent\",\n",
|
||||
" tools=[python_tool],\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"await ToolUseAgent.register(\n",
|
||||
" runtime,\n",
|
||||
" \"tool_enabled_agent\",\n",
|
||||
" lambda: ToolUseAgent(\n",
|
||||
" description=\"Tool Use Agent\",\n",
|
||||
" system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
|
||||
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
|
||||
" tool_schema=[python_tool.schema],\n",
|
||||
" tool_agent_type=tool_agent_type,\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
|
||||
"The intervention handler will prompt you for permission to execute the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The output of the code is: **Hello, World!**\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Start the runtime and the docker executor.\n",
|
||||
"await docker_executor.start()\n",
|
||||
"runtime.start()\n",
|
||||
"\n",
|
||||
"# Send a task to the tool user.\n",
|
||||
"response = await runtime.send_message(\n",
|
||||
" Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
|
||||
")\n",
|
||||
"print(response.content)\n",
|
||||
"\n",
|
||||
"# Stop the runtime and the docker executor.\n",
|
||||
"await runtime.stop()\n",
|
||||
"await docker_executor.stop()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -34,13 +34,13 @@
|
||||
" DefaultTopicId,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" TopicId,\n",
|
||||
" TypeSubscription,\n",
|
||||
" default_subscription,\n",
|
||||
" message_handler,\n",
|
||||
" type_subscription,\n",
|
||||
")\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime"
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -78,11 +78,11 @@
|
||||
" Image,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" TopicId,\n",
|
||||
" TypeSubscription,\n",
|
||||
" message_handler,\n",
|
||||
")\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" AssistantMessage,\n",
|
||||
" ChatCompletionClient,\n",
|
||||
|
||||
@@ -56,8 +56,15 @@
|
||||
"import uuid\n",
|
||||
"from typing import List, Tuple\n",
|
||||
"\n",
|
||||
"from autogen_core import FunctionCall, MessageContext, RoutedAgent, TopicId, TypeSubscription, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import (\n",
|
||||
" FunctionCall,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" TopicId,\n",
|
||||
" TypeSubscription,\n",
|
||||
" message_handler,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" AssistantMessage,\n",
|
||||
" ChatCompletionClient,\n",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -441,8 +441,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_core import DefaultTopicId\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import DefaultTopicId, SingleThreadedAgentRuntime\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
|
||||
@@ -18,7 +18,7 @@ The key can correspond to a user id, a session id, or could just be "default" if
|
||||
|
||||
## How do I increase the GRPC message size?
|
||||
|
||||
If you need to provide custom gRPC options, such as overriding the `max_send_message_length` and `max_receive_message_length`, you can define an `extra_grpc_config` variable and pass it to both the `WorkerAgentRuntimeHost` and `WorkerAgentRuntime` instances.
|
||||
If you need to provide custom gRPC options, such as overriding the `max_send_message_length` and `max_receive_message_length`, you can define an `extra_grpc_config` variable and pass it to both the `GrpcWorkerAgentRuntimeHost` and `GrpcWorkerAgentRuntime` instances.
|
||||
|
||||
```python
|
||||
# Define custom gRPC options
|
||||
@@ -27,10 +27,10 @@ extra_grpc_config = [
|
||||
("grpc.max_receive_message_length", new_max_size),
|
||||
]
|
||||
|
||||
# Create instances of WorkerAgentRuntimeHost and WorkerAgentRuntime with the custom gRPC options
|
||||
# Create instances of GrpcWorkerAgentRuntimeHost and GrpcWorkerAgentRuntime with the custom gRPC options
|
||||
|
||||
host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
```
|
||||
|
||||
**Note**: When `WorkerAgentRuntime` creates a host connection for the clients, it uses `DEFAULT_GRPC_CONFIG` from `HostConnection` class as default set of values which will can be overriden if you pass parameters with the same name using `extra_grpc_config`.
|
||||
**Note**: When `GrpcWorkerAgentRuntime` creates a host connection for the clients, it uses `DEFAULT_GRPC_CONFIG` from `HostConnection` class as default set of values which will can be overriden if you pass parameters with the same name using `extra_grpc_config`.
|
||||
|
||||
@@ -117,7 +117,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -132,7 +132,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import SingleThreadedAgentRuntime\n",
|
||||
"\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"await MyAgent.register(runtime, \"my_agent\", lambda: MyAgent())"
|
||||
|
||||
@@ -28,18 +28,18 @@
|
||||
"```\n",
|
||||
"````\n",
|
||||
"\n",
|
||||
"We can start a host service using {py:class}`~autogen_core.application.WorkerAgentRuntimeHost`."
|
||||
"We can start a host service using {py:class}`~autogen_core.application.GrpcWorkerAgentRuntimeHost`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core.application import WorkerAgentRuntimeHost\n",
|
||||
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n",
|
||||
"\n",
|
||||
"host = WorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
|
||||
"host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
|
||||
"host.start() # Start a host service in the background."
|
||||
]
|
||||
},
|
||||
@@ -94,7 +94,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can set up the worker agent runtimes.\n",
|
||||
"We use {py:class}`~autogen_core.application.WorkerAgentRuntime`.\n",
|
||||
"We use {py:class}`~autogen_core.application.GrpcWorkerAgentRuntime`.\n",
|
||||
"We set up two worker runtimes. Each runtime hosts one agent.\n",
|
||||
"All agents publish and subscribe to the default topic, so they can see all\n",
|
||||
"messages being published.\n",
|
||||
@@ -104,7 +104,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -127,13 +127,13 @@
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"\n",
|
||||
"from autogen_core.application import WorkerAgentRuntime\n",
|
||||
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n",
|
||||
"\n",
|
||||
"worker1 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker1.start()\n",
|
||||
"await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n",
|
||||
"\n",
|
||||
"worker2 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
|
||||
"worker2.start()\n",
|
||||
"await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n",
|
||||
"\n",
|
||||
@@ -149,7 +149,7 @@
|
||||
"source": [
|
||||
"We can see each agent published exactly 5 messages.\n",
|
||||
"\n",
|
||||
"To stop the worker runtimes, we can call {py:meth}`~autogen_core.application.WorkerAgentRuntime.stop`."
|
||||
"To stop the worker runtimes, we can call {py:meth}`~autogen_core.application.GrpcWorkerAgentRuntime.stop`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -169,7 +169,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can call {py:meth}`~autogen_core.application.WorkerAgentRuntimeHost.stop`\n",
|
||||
"We can call {py:meth}`~autogen_core.application.GrpcWorkerAgentRuntimeHost.stop`\n",
|
||||
"to stop the host service."
|
||||
]
|
||||
},
|
||||
|
||||
@@ -90,8 +90,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MyAgent(RoutedAgent):\n",
|
||||
@@ -298,8 +297,7 @@
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"\n",
|
||||
"from autogen_core import MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -48,7 +48,7 @@ Now you can send the trace_provider when creating your runtime:
|
||||
# for single threaded runtime
|
||||
single_threaded_runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
|
||||
# or for worker runtime
|
||||
worker_runtime = WorkerAgentRuntime(tracer_provider=tracer_provider)
|
||||
worker_runtime = GrpcWorkerAgentRuntime(tracer_provider=tracer_provider)
|
||||
```
|
||||
|
||||
And that's it! Your application is now instrumented with open telemetry. You can now view your telemetry data in your telemetry backend.
|
||||
@@ -65,5 +65,5 @@ tracer_provider = trace.get_tracer_provider()
|
||||
# for single threaded runtime
|
||||
single_threaded_runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
|
||||
# or for worker runtime
|
||||
worker_runtime = WorkerAgentRuntime(tracer_provider=tracer_provider)
|
||||
worker_runtime = GrpcWorkerAgentRuntime(tracer_provider=tracer_provider)
|
||||
```
|
||||
|
||||
@@ -1,315 +1,321 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tools\n",
|
||||
"\n",
|
||||
"Tools are code that can be executed by an agent to perform actions. A tool\n",
|
||||
"can be a simple function such as a calculator, or an API call to a third-party service\n",
|
||||
"such as stock price lookup or weather forecast.\n",
|
||||
"In the context of AI agents, tools are designed to be executed by agents in\n",
|
||||
"response to model-generated function calls.\n",
|
||||
"\n",
|
||||
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
|
||||
"tools and utilities for creating and running custom tools."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Built-in Tools\n",
|
||||
"\n",
|
||||
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
|
||||
"which allows agents to execute Python code snippets.\n",
|
||||
"\n",
|
||||
"Here is how you create the tool and use it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello, world!\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
|
||||
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
|
||||
"\n",
|
||||
"# Create the tool.\n",
|
||||
"code_executor = DockerCommandLineCodeExecutor()\n",
|
||||
"await code_executor.start()\n",
|
||||
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
|
||||
"cancellation_token = CancellationToken()\n",
|
||||
"\n",
|
||||
"# Use the tool directly without an agent.\n",
|
||||
"code = \"print('Hello, world!')\"\n",
|
||||
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
|
||||
"print(code_execution_tool.return_value_as_string(result))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`\n",
|
||||
"class is a built-in code executor that runs Python code snippets in a subprocess\n",
|
||||
"in the local command line environment.\n",
|
||||
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
|
||||
"and provides a simple interface to execute Python code snippets.\n",
|
||||
"\n",
|
||||
"Other built-in tools will be added in the future."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Custom Function Tools\n",
|
||||
"\n",
|
||||
"A tool can also be a simple Python function that performs a specific action.\n",
|
||||
"To create a custom function tool, you just need to create a Python function\n",
|
||||
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
|
||||
"\n",
|
||||
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
|
||||
"to inform the LLM when and how to use a given function. The description provides context\n",
|
||||
"about the function’s purpose and intended use cases, while type annotations inform the LLM about\n",
|
||||
"the expected parameters and return type.\n",
|
||||
"\n",
|
||||
"For example, a simple tool to obtain the stock price of a company might look like this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"80.44429939059668\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_core.components.tools import FunctionTool\n",
|
||||
"from typing_extensions import Annotated\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
|
||||
" # Returns a random stock price for demonstration purposes.\n",
|
||||
" return random.uniform(10, 200)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create a function tool.\n",
|
||||
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
|
||||
"\n",
|
||||
"# Run the tool.\n",
|
||||
"cancellation_token = CancellationToken()\n",
|
||||
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
|
||||
"\n",
|
||||
"# Print the result.\n",
|
||||
"print(stock_price_tool.return_value_as_string(result))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tool-Equipped Agent\n",
|
||||
"\n",
|
||||
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`,\n",
|
||||
"by using it in a composition pattern.\n",
|
||||
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
|
||||
"as an inner agent for executing tools."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, AgentInstantiationContext, MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" ChatCompletionClient,\n",
|
||||
" LLMMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" UserMessage,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
|
||||
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class Message:\n",
|
||||
" content: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ToolUseAgent(RoutedAgent):\n",
|
||||
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
|
||||
" super().__init__(\"An agent with tools\")\n",
|
||||
" self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
|
||||
" self._model_client = model_client\n",
|
||||
" self._tool_schema = tool_schema\n",
|
||||
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
|
||||
" # Create a session of messages.\n",
|
||||
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
|
||||
" # Run the caller loop to handle tool calls.\n",
|
||||
" messages = await tool_agent_caller_loop(\n",
|
||||
" self,\n",
|
||||
" tool_agent_id=self._tool_agent_id,\n",
|
||||
" model_client=self._model_client,\n",
|
||||
" input_messages=session,\n",
|
||||
" tool_schema=self._tool_schema,\n",
|
||||
" cancellation_token=ctx.cancellation_token,\n",
|
||||
" )\n",
|
||||
" # Return the final response.\n",
|
||||
" assert isinstance(messages[-1].content, str)\n",
|
||||
" return Message(content=messages[-1].content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`, \n",
|
||||
"to handle the interaction between the model and the tool agent.\n",
|
||||
"The core idea can be described using a simple control flow graph:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
|
||||
"and determines whether the model has generated a tool call.\n",
|
||||
"If the model has generated tool calls, then the handler sends a function call\n",
|
||||
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
|
||||
"to execute the tools,\n",
|
||||
"and then queries the model again with the results of the tool calls.\n",
|
||||
"This process continues until the model stops generating tool calls,\n",
|
||||
"at which point the final response is returned to the user.\n",
|
||||
"\n",
|
||||
"By having the tool execution logic in a separate agent,\n",
|
||||
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
|
||||
"can be observed externally and intercepted if necessary.\n",
|
||||
"\n",
|
||||
"To run the agent, we need to create a runtime and register the agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AgentType(type='tool_use_agent')"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a runtime.\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"# Create the tools.\n",
|
||||
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
|
||||
"# Register the agents.\n",
|
||||
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
|
||||
"await ToolUseAgent.register(\n",
|
||||
" runtime,\n",
|
||||
" \"tool_use_agent\",\n",
|
||||
" lambda: ToolUseAgent(\n",
|
||||
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This example uses the {py:class}`autogen_core.components.models.OpenAIChatCompletionClient`,\n",
|
||||
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
|
||||
"Let's test the agent with a question about stock price."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Start processing messages.\n",
|
||||
"runtime.start()\n",
|
||||
"# Send a direct message to the tool agent.\n",
|
||||
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
|
||||
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
|
||||
"print(response.content)\n",
|
||||
"# Stop processing messages.\n",
|
||||
"await runtime.stop()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "autogen_core",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tools\n",
|
||||
"\n",
|
||||
"Tools are code that can be executed by an agent to perform actions. A tool\n",
|
||||
"can be a simple function such as a calculator, or an API call to a third-party service\n",
|
||||
"such as stock price lookup or weather forecast.\n",
|
||||
"In the context of AI agents, tools are designed to be executed by agents in\n",
|
||||
"response to model-generated function calls.\n",
|
||||
"\n",
|
||||
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
|
||||
"tools and utilities for creating and running custom tools."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Built-in Tools\n",
|
||||
"\n",
|
||||
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
|
||||
"which allows agents to execute Python code snippets.\n",
|
||||
"\n",
|
||||
"Here is how you create the tool and use it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Hello, world!\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
|
||||
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
|
||||
"\n",
|
||||
"# Create the tool.\n",
|
||||
"code_executor = DockerCommandLineCodeExecutor()\n",
|
||||
"await code_executor.start()\n",
|
||||
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
|
||||
"cancellation_token = CancellationToken()\n",
|
||||
"\n",
|
||||
"# Use the tool directly without an agent.\n",
|
||||
"code = \"print('Hello, world!')\"\n",
|
||||
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
|
||||
"print(code_execution_tool.return_value_as_string(result))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`\n",
|
||||
"class is a built-in code executor that runs Python code snippets in a subprocess\n",
|
||||
"in the local command line environment.\n",
|
||||
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
|
||||
"and provides a simple interface to execute Python code snippets.\n",
|
||||
"\n",
|
||||
"Other built-in tools will be added in the future."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Custom Function Tools\n",
|
||||
"\n",
|
||||
"A tool can also be a simple Python function that performs a specific action.\n",
|
||||
"To create a custom function tool, you just need to create a Python function\n",
|
||||
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
|
||||
"\n",
|
||||
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
|
||||
"to inform the LLM when and how to use a given function. The description provides context\n",
|
||||
"about the function’s purpose and intended use cases, while type annotations inform the LLM about\n",
|
||||
"the expected parameters and return type.\n",
|
||||
"\n",
|
||||
"For example, a simple tool to obtain the stock price of a company might look like this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"80.44429939059668\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"\n",
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"from autogen_core.components.tools import FunctionTool\n",
|
||||
"from typing_extensions import Annotated\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
|
||||
" # Returns a random stock price for demonstration purposes.\n",
|
||||
" return random.uniform(10, 200)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Create a function tool.\n",
|
||||
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
|
||||
"\n",
|
||||
"# Run the tool.\n",
|
||||
"cancellation_token = CancellationToken()\n",
|
||||
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
|
||||
"\n",
|
||||
"# Print the result.\n",
|
||||
"print(stock_price_tool.return_value_as_string(result))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tool-Equipped Agent\n",
|
||||
"\n",
|
||||
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`,\n",
|
||||
"by using it in a composition pattern.\n",
|
||||
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
|
||||
"as an inner agent for executing tools."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"from autogen_core import (\n",
|
||||
" AgentId,\n",
|
||||
" AgentInstantiationContext,\n",
|
||||
" MessageContext,\n",
|
||||
" RoutedAgent,\n",
|
||||
" SingleThreadedAgentRuntime,\n",
|
||||
" message_handler,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" ChatCompletionClient,\n",
|
||||
" LLMMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" UserMessage,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
|
||||
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@dataclass\n",
|
||||
"class Message:\n",
|
||||
" content: str\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ToolUseAgent(RoutedAgent):\n",
|
||||
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
|
||||
" super().__init__(\"An agent with tools\")\n",
|
||||
" self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
|
||||
" self._model_client = model_client\n",
|
||||
" self._tool_schema = tool_schema\n",
|
||||
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
|
||||
" # Create a session of messages.\n",
|
||||
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
|
||||
" # Run the caller loop to handle tool calls.\n",
|
||||
" messages = await tool_agent_caller_loop(\n",
|
||||
" self,\n",
|
||||
" tool_agent_id=self._tool_agent_id,\n",
|
||||
" model_client=self._model_client,\n",
|
||||
" input_messages=session,\n",
|
||||
" tool_schema=self._tool_schema,\n",
|
||||
" cancellation_token=ctx.cancellation_token,\n",
|
||||
" )\n",
|
||||
" # Return the final response.\n",
|
||||
" assert isinstance(messages[-1].content, str)\n",
|
||||
" return Message(content=messages[-1].content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`, \n",
|
||||
"to handle the interaction between the model and the tool agent.\n",
|
||||
"The core idea can be described using a simple control flow graph:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
|
||||
"and determines whether the model has generated a tool call.\n",
|
||||
"If the model has generated tool calls, then the handler sends a function call\n",
|
||||
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
|
||||
"to execute the tools,\n",
|
||||
"and then queries the model again with the results of the tool calls.\n",
|
||||
"This process continues until the model stops generating tool calls,\n",
|
||||
"at which point the final response is returned to the user.\n",
|
||||
"\n",
|
||||
"By having the tool execution logic in a separate agent,\n",
|
||||
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
|
||||
"can be observed externally and intercepted if necessary.\n",
|
||||
"\n",
|
||||
"To run the agent, we need to create a runtime and register the agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AgentType(type='tool_use_agent')"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create a runtime.\n",
|
||||
"runtime = SingleThreadedAgentRuntime()\n",
|
||||
"# Create the tools.\n",
|
||||
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
|
||||
"# Register the agents.\n",
|
||||
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
|
||||
"await ToolUseAgent.register(\n",
|
||||
" runtime,\n",
|
||||
" \"tool_use_agent\",\n",
|
||||
" lambda: ToolUseAgent(\n",
|
||||
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This example uses the {py:class}`autogen_core.components.models.OpenAIChatCompletionClient`,\n",
|
||||
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
|
||||
"Let's test the agent with a question about stock price."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Start processing messages.\n",
|
||||
"runtime.start()\n",
|
||||
"# Send a direct message to the tool agent.\n",
|
||||
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
|
||||
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
|
||||
"print(response.content)\n",
|
||||
"# Stop processing messages.\n",
|
||||
"await runtime.stop()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "autogen_core",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
@@ -310,7 +310,7 @@
|
||||
"source": [
|
||||
"import tempfile\n",
|
||||
"\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||
"\n",
|
||||
|
||||
@@ -25,7 +25,6 @@ dependencies = [
|
||||
"opentelemetry-api~=1.27.0",
|
||||
"asyncio_atexit",
|
||||
"jsonref~=1.1.0",
|
||||
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -35,6 +34,7 @@ grpc = [
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"autogen_test_utils",
|
||||
"aiofiles",
|
||||
"azure-identity",
|
||||
"chess",
|
||||
@@ -75,6 +75,8 @@ dev-dependencies = [
|
||||
"autodoc_pydantic~=2.2",
|
||||
"pygments",
|
||||
|
||||
"autogen_ext==0.4.0.dev8",
|
||||
|
||||
# Documentation tooling
|
||||
"sphinx-autobuild",
|
||||
]
|
||||
|
||||
@@ -7,8 +7,14 @@ import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime, DefaultSubscription, DefaultTopicId
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentRuntime,
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
SingleThreadedAgentRuntime,
|
||||
)
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.components.models import SystemMessage
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Distributed Group Chat
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntimeHost
|
||||
|
||||
This example runs a gRPC server using [WorkerAgentRuntimeHost](../../src/autogen_core/application/_worker_runtime_host.py) and instantiates three distributed runtimes using [WorkerAgentRuntime](../../src/autogen_core/application/_worker_runtime.py). These runtimes connect to the gRPC server as hosts and facilitate a round-robin distributed group chat. This example leverages the [Azure OpenAI Service](https://azure.microsoft.com/en-us/products/ai-services/openai-service) to implement writer and editor LLM agents. Agents are instructed to provide concise answers, as the primary goal of this example is to showcase the distributed runtime rather than the quality of agent responses.
|
||||
This example runs a gRPC server using [GrpcWorkerAgentRuntimeHost](../../src/autogen_core/application/_worker_runtime_host.py) and instantiates three distributed runtimes using [GrpcWorkerAgentRuntime](../../src/autogen_core/application/_worker_runtime.py). These runtimes connect to the gRPC server as hosts and facilitate a round-robin distributed group chat. This example leverages the [Azure OpenAI Service](https://azure.microsoft.com/en-us/products/ai-services/openai-service) to implement writer and editor LLM agents. Agents are instructed to provide concise answers, as the primary goal of this example is to showcase the distributed runtime rather than the quality of agent responses.
|
||||
|
||||
## Setup
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from uuid import uuid4
|
||||
|
||||
from _types import GroupChatMessage, MessageChunk, RequestToSpeak, UIAgentConfig
|
||||
from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, message_handler
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@@ -13,6 +12,7 @@ from autogen_core.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
@@ -168,7 +168,7 @@ class UIAgent(RoutedAgent):
|
||||
|
||||
|
||||
async def publish_message_to_ui(
|
||||
runtime: RoutedAgent | WorkerAgentRuntime,
|
||||
runtime: RoutedAgent | GrpcWorkerAgentRuntime,
|
||||
source: str,
|
||||
user_message: str,
|
||||
ui_config: UIAgentConfig,
|
||||
@@ -193,7 +193,7 @@ async def publish_message_to_ui(
|
||||
|
||||
|
||||
async def publish_message_to_ui_and_backend(
|
||||
runtime: RoutedAgent | WorkerAgentRuntime,
|
||||
runtime: RoutedAgent | GrpcWorkerAgentRuntime,
|
||||
source: str,
|
||||
user_message: str,
|
||||
ui_config: UIAgentConfig,
|
||||
|
||||
@@ -8,15 +8,15 @@ from _utils import get_serializers, load_config, set_all_log_levels
|
||||
from autogen_core import (
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.models import AzureOpenAIChatCompletionClient
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
|
||||
async def main(config: AppConfig):
|
||||
set_all_log_levels(logging.ERROR)
|
||||
editor_agent_runtime = WorkerAgentRuntime(host_address=config.host.address)
|
||||
editor_agent_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
|
||||
editor_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
await asyncio.sleep(4)
|
||||
Console().print(Markdown("Starting **`Editor Agent`**"))
|
||||
|
||||
@@ -8,8 +8,8 @@ from _utils import get_serializers, load_config, set_all_log_levels
|
||||
from autogen_core import (
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.models import AzureOpenAIChatCompletionClient
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
@@ -18,7 +18,7 @@ set_all_log_levels(logging.ERROR)
|
||||
|
||||
async def main(config: AppConfig):
|
||||
set_all_log_levels(logging.ERROR)
|
||||
group_chat_manager_runtime = WorkerAgentRuntime(host_address=config.host.address)
|
||||
group_chat_manager_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
|
||||
|
||||
group_chat_manager_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@@ -2,13 +2,13 @@ import asyncio
|
||||
|
||||
from _types import HostConfig
|
||||
from _utils import load_config
|
||||
from autogen_core.application import WorkerAgentRuntimeHost
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
|
||||
async def main(host_config: HostConfig):
|
||||
host = WorkerAgentRuntimeHost(address=host_config.address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_config.address)
|
||||
host.start()
|
||||
|
||||
console = Console()
|
||||
|
||||
@@ -9,7 +9,7 @@ from _utils import get_serializers, load_config, set_all_log_levels
|
||||
from autogen_core import (
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
from chainlit import Message # type: ignore [reportAttributeAccessIssue]
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
@@ -36,7 +36,7 @@ async def send_cl_stream(msg: MessageChunk) -> None:
|
||||
|
||||
async def main(config: AppConfig):
|
||||
set_all_log_levels(logging.ERROR)
|
||||
ui_agent_runtime = WorkerAgentRuntime(host_address=config.host.address)
|
||||
ui_agent_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
|
||||
|
||||
ui_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@@ -8,15 +8,15 @@ from _utils import get_serializers, load_config, set_all_log_levels
|
||||
from autogen_core import (
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.models import AzureOpenAIChatCompletionClient
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
|
||||
async def main(config: AppConfig) -> None:
|
||||
set_all_log_levels(logging.ERROR)
|
||||
writer_agent_runtime = WorkerAgentRuntime(host_address=config.host.address)
|
||||
writer_agent_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
|
||||
writer_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
|
||||
await asyncio.sleep(3)
|
||||
Console().print(Markdown("Starting **`Writer Agent`**"))
|
||||
|
||||
@@ -2,12 +2,12 @@ import asyncio
|
||||
import logging
|
||||
import platform
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntimeHost
|
||||
from autogen_core.application.logging import TRACE_LOGGER_NAME
|
||||
from autogen_core import TRACE_LOGGER_NAME
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost
|
||||
|
||||
|
||||
async def run_host():
|
||||
host = WorkerAgentRuntimeHost(address="localhost:50051")
|
||||
host = GrpcWorkerAgentRuntimeHost(address="localhost:50051")
|
||||
host.start() # Start a host service in the background.
|
||||
if platform.system() == "Windows":
|
||||
try:
|
||||
|
||||
@@ -32,7 +32,7 @@ from _semantic_router_components import (
|
||||
WorkerAgentMessage,
|
||||
)
|
||||
from autogen_core import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId, MessageContext
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
|
||||
|
||||
class MockIntentClassifier(IntentClassifierBase):
|
||||
@@ -78,7 +78,7 @@ async def output_result(
|
||||
|
||||
|
||||
async def run_workers():
|
||||
agent_runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
agent_runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
|
||||
agent_runtime.start()
|
||||
|
||||
|
||||
@@ -37,10 +37,10 @@ from autogen_core import (
|
||||
FunctionCall,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
message_handler,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core.base.intervention import DefaultInterventionHandler
|
||||
from autogen_core.components.model_context import BufferedChatCompletionContext
|
||||
from autogen_core.components.models import (
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from agents import CascadingMessage, ObserverAgent
|
||||
from autogen_core import DefaultTopicId, try_get_known_serializers_for_type
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
|
||||
runtime.start()
|
||||
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())
|
||||
|
||||
@@ -2,11 +2,11 @@ import uuid
|
||||
|
||||
from agents import CascadingAgent, ReceiveMessageEvent
|
||||
from autogen_core import try_get_known_serializers_for_type
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
|
||||
runtime.start()
|
||||
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
|
||||
from autogen_core.application import WorkerAgentRuntimeHost
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
service = WorkerAgentRuntimeHost(address="localhost:50051")
|
||||
service = GrpcWorkerAgentRuntimeHost(address="localhost:50051")
|
||||
service.start()
|
||||
await service.stop_when_signal()
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from autogen_core import (
|
||||
message_handler,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -72,7 +72,7 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.start()
|
||||
for t in [AskToGreet, Greeting, ReturnedGreeting, Feedback, ReturnedFeedback]:
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(t))
|
||||
|
||||
@@ -10,7 +10,7 @@ from autogen_core import (
|
||||
RoutedAgent,
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,7 +53,7 @@ class GreeterAgent(RoutedAgent):
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = WorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
|
||||
runtime.start()
|
||||
|
||||
await ReceiveAgent.register(
|
||||
|
||||
@@ -12,7 +12,7 @@ from autogen_core import (
|
||||
TypeSubscription,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
|
||||
|
||||
# Add the local package directory to sys.path
|
||||
thisdir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -34,7 +34,7 @@ async def main() -> None:
|
||||
agentHost = agentHost[8:]
|
||||
agnext_logger.info("0")
|
||||
agnext_logger.info(agentHost)
|
||||
runtime = WorkerAgentRuntime(host_address=agentHost, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE)
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=agentHost, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE)
|
||||
|
||||
agnext_logger.info("1")
|
||||
runtime.start()
|
||||
|
||||
@@ -12,6 +12,7 @@ from ._agent_type import AgentType
|
||||
from ._base_agent import BaseAgent
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._closure_agent import ClosureAgent, ClosureContext
|
||||
from ._constants import EVENT_LOGGER_NAME, ROOT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
|
||||
from ._default_topic import DefaultTopicId
|
||||
from ._image import Image
|
||||
@@ -25,6 +26,7 @@ from ._serialization import (
|
||||
UnknownPayload,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._topic import TopicId
|
||||
@@ -66,4 +68,8 @@ __all__ = [
|
||||
"TypePrefixSubscription",
|
||||
"JSON_DATA_CONTENT_TYPE",
|
||||
"PROTOBUF_DATA_CONTENT_TYPE",
|
||||
"SingleThreadedAgentRuntime",
|
||||
"ROOT_LOGGER_NAME",
|
||||
"EVENT_LOGGER_NAME",
|
||||
"TRACE_LOGGER_NAME",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
ROOT_LOGGER_NAME = "autogen_core"
|
||||
"""str: Logger name used for structured event logging"""
|
||||
|
||||
EVENT_LOGGER_NAME = "autogen_core.events"
|
||||
"""str: Logger name used for structured event logging"""
|
||||
|
||||
|
||||
TRACE_LOGGER_NAME = "autogen_core.trace"
|
||||
"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
|
||||
@@ -1,11 +1,11 @@
|
||||
from collections import defaultdict
|
||||
from typing import Awaitable, Callable, DefaultDict, List, Set
|
||||
|
||||
from .._agent import Agent
|
||||
from .._agent_id import AgentId
|
||||
from .._agent_type import AgentType
|
||||
from .._subscription import Subscription
|
||||
from .._topic import TopicId
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_type import AgentType
|
||||
from ._subscription import Subscription
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
async def get_impl(
|
||||
@@ -0,0 +1,687 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from autogen_core._serialization import MessageSerializer, SerializationRegistry
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
from ._agent_metadata import AgentMetadata
|
||||
from ._agent_runtime import AgentRuntime
|
||||
from ._agent_type import AgentType
|
||||
from ._cancellation_token import CancellationToken
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._runtime_impl_helpers import SubscriptionManager, get_impl
|
||||
from ._subscription import Subscription
|
||||
from ._subscription_context import SubscriptionInstantiationContext
|
||||
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
|
||||
from ._topic import TopicId
|
||||
from .base.intervention import DropMessage, InterventionHandler
|
||||
from .exceptions import MessageDroppedException
|
||||
|
||||
logger = logging.getLogger("autogen_core")
|
||||
event_logger = logging.getLogger("autogen_core.events")
|
||||
|
||||
# We use a type parameter in some functions which shadows the built-in `type` function.
|
||||
# This is a workaround to avoid shadowing the built-in `type` function.
|
||||
type_func_alias = type
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PublishMessageEnvelope:
|
||||
"""A message envelope for publishing messages to all agents that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
cancellation_token: CancellationToken
|
||||
sender: AgentId | None
|
||||
topic_id: TopicId
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
message_id: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SendMessageEnvelope:
|
||||
"""A message envelope for sending a message to a specific agent that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
sender: AgentId | None
|
||||
recipient: AgentId
|
||||
future: Future[Any]
|
||||
cancellation_token: CancellationToken
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ResponseMessageEnvelope:
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: Any
|
||||
future: Future[Any]
|
||||
sender: AgentId
|
||||
recipient: AgentId | None
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self) -> None:
|
||||
self._count: int = 0
|
||||
self.threadLock = threading.Lock()
|
||||
|
||||
def increment(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count += 1
|
||||
self.threadLock.release()
|
||||
|
||||
def get(self) -> int:
|
||||
return self._count
|
||||
|
||||
def decrement(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count -= 1
|
||||
self.threadLock.release()
|
||||
|
||||
|
||||
class RunContext:
|
||||
class RunState(Enum):
|
||||
RUNNING = 0
|
||||
CANCELLED = 1
|
||||
UNTIL_IDLE = 2
|
||||
|
||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
self._run_state = RunContext.RunState.RUNNING
|
||||
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
|
||||
self._run_task = asyncio.create_task(self._run())
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._end_condition():
|
||||
return
|
||||
|
||||
await self._runtime.process_next()
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.CANCELLED
|
||||
self._end_condition = self._stop_when_cancelled
|
||||
await self._run_task
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||
self._end_condition = self._stop_when_idle
|
||||
await self._run_task
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
async with self._lock:
|
||||
self._end_condition = condition
|
||||
await self._run_task
|
||||
|
||||
def _stop_when_cancelled(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.CANCELLED
|
||||
|
||||
def _stop_when_idle(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
|
||||
|
||||
|
||||
def _warn_if_none(value: Any, handler_name: str) -> None:
|
||||
"""
|
||||
Utility function to check if the intervention handler returned None and issue a warning.
|
||||
|
||||
Args:
|
||||
value: The return value to check
|
||||
handler_name: Name of the intervention handler method for the warning message
|
||||
"""
|
||||
if value is None:
|
||||
warnings.warn(
|
||||
f"Intervention handler {handler_name} returned None. This might be unintentional. "
|
||||
"Consider returning the original message or DropMessage explicitly.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
intervention_handlers: List[InterventionHandler] | None = None,
|
||||
tracer_provider: TracerProvider | None = None,
|
||||
) -> None:
|
||||
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._agent_factories: Dict[
|
||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||
] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._intervention_handlers = intervention_handlers
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._run_context: RunContext | None = None
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
self,
|
||||
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
|
||||
return self._message_queue
|
||||
|
||||
@property
|
||||
def outstanding_tasks(self) -> int:
|
||||
return self._outstanding_tasks.get()
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
return set(self._agent_factories.keys())
|
||||
|
||||
# Returns the response of the message
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Any:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
with self._tracer_helper.trace_block(
|
||||
"create",
|
||||
recipient,
|
||||
parent=None,
|
||||
extraAttributes={"message_type": type(message).__name__},
|
||||
):
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
if recipient.type not in self._known_agent_names:
|
||||
future.set_exception(Exception("Recipient not found"))
|
||||
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
|
||||
|
||||
self._message_queue.append(
|
||||
SendMessageEnvelope(
|
||||
message=message,
|
||||
recipient=recipient,
|
||||
future=future,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
)
|
||||
)
|
||||
|
||||
cancellation_token.link_future(future)
|
||||
|
||||
return await future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
with self._tracer_helper.trace_block(
|
||||
"create",
|
||||
topic_id,
|
||||
parent=None,
|
||||
extraAttributes={"message_type": type(message).__name__},
|
||||
):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}")
|
||||
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=None,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
self._message_queue.append(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
topic_id=topic_id,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent_id in self._instantiated_agents:
|
||||
state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state())
|
||||
return state
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent_id_str in state:
|
||||
agent_id = AgentId.from_str(agent_id_str)
|
||||
if agent_id.type in self._known_agent_names:
|
||||
await (await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
|
||||
recipient = message_envelope.recipient
|
||||
# todo: check if recipient is in the known namespaces
|
||||
# assert recipient in self._agents
|
||||
|
||||
try:
|
||||
# TODO use id
|
||||
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
recipient_agent = await self._get_agent(recipient)
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
# Will be fixed when send API removed
|
||||
message_id="NOT_DEFINED_TODO_FIX",
|
||||
)
|
||||
with MessageHandlerContext.populate_context(recipient_agent.id):
|
||||
response = await recipient_agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
except CancelledError as e:
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_exception(e)
|
||||
self._outstanding_tasks.decrement()
|
||||
return
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
self._outstanding_tasks.decrement()
|
||||
return
|
||||
|
||||
self._message_queue.append(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
sender=message_envelope.recipient,
|
||||
recipient=message_envelope.sender,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
)
|
||||
)
|
||||
self._outstanding_tasks.decrement()
|
||||
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
|
||||
try:
|
||||
responses: List[Awaitable[Any]] = []
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
|
||||
for agent_id in recipients:
|
||||
# Avoid sending the message back to the sender
|
||||
if message_envelope.sender is not None and agent_id == message_envelope.sender:
|
||||
continue
|
||||
|
||||
sender_agent = (
|
||||
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
||||
)
|
||||
sender_name = str(sender_agent.id) if sender_agent is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=agent,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=message_envelope.topic_id,
|
||||
is_rpc=False,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
message_id=message_envelope.message_id,
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
|
||||
async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
|
||||
with self._tracer_helper.trace_block("process", agent.id, parent=None):
|
||||
with MessageHandlerContext.populate_context(agent.id):
|
||||
return await agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
|
||||
future = _on_message(agent, message_context)
|
||||
responses.append(future)
|
||||
|
||||
await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
# Ignore cancelled errors from logs
|
||||
if isinstance(e, CancelledError):
|
||||
return
|
||||
logger.error("Error processing publish message", exc_info=True)
|
||||
finally:
|
||||
self._outstanding_tasks.decrement()
|
||||
# TODO if responses are given for a publish
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata):
|
||||
content = (
|
||||
message_envelope.message.__dict__
|
||||
if hasattr(message_envelope.message, "__dict__")
|
||||
else message_envelope.message
|
||||
)
|
||||
logger.info(
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=message_envelope.recipient,
|
||||
# kind=MessageKind.RESPOND,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
self._outstanding_tasks.decrement()
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
async def process_next(self) -> None:
|
||||
"""Process the next message in the queue."""
|
||||
|
||||
if len(self._message_queue) == 0:
|
||||
# Yield control to the event loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
return
|
||||
message_envelope = self._message_queue.pop(0)
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
with self._tracer_helper.trace_block(
|
||||
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
|
||||
):
|
||||
try:
|
||||
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
|
||||
_warn_if_none(temp_message, "on_send")
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_send(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
with self._tracer_helper.trace_block(
|
||||
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
|
||||
):
|
||||
try:
|
||||
temp_message = await handler.on_publish(message, sender=sender)
|
||||
_warn_if_none(temp_message, "on_publish")
|
||||
except BaseException as e:
|
||||
# TODO: we should raise the intervention exception to the publisher.
|
||||
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
# TODO log message dropped
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_publish(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
try:
|
||||
temp_message = await handler.on_response(message, sender=sender, recipient=recipient)
|
||||
_warn_if_none(temp_message, "on_response")
|
||||
except BaseException as e:
|
||||
# TODO: should we raise the exception to sender of the response instead?
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_response(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the runtime message processing loop."""
|
||||
if self._run_context is not None:
|
||||
raise RuntimeError("Runtime is already started")
|
||||
self._run_context = RunContext(self)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the runtime message processing loop."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop()
|
||||
self._run_context = None
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
"""Stop the runtime message processing loop when there is
|
||||
no outstanding message being processed or queued."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when_idle()
|
||||
self._run_context = None
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
"""Stop the runtime message processing loop when the condition is met."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when(condition)
|
||||
self._run_context = None
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return (await self._get_agent(agent)).metadata
|
||||
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
return await (await self._get_agent(agent)).save_state()
|
||||
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
await (await self._get_agent(agent)).load_state(state)
|
||||
|
||||
@deprecated(
|
||||
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
|
||||
)
|
||||
async def register(
|
||||
self,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
| list[Subscription]
|
||||
| None = None,
|
||||
) -> AgentType:
|
||||
if type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
|
||||
if subscriptions is not None:
|
||||
if callable(subscriptions):
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions
|
||||
|
||||
for subscription in subscriptions_list:
|
||||
await self.add_subscription(subscription)
|
||||
|
||||
self._agent_factories[type] = agent_factory
|
||||
return AgentType(type)
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
*,
|
||||
type: AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
if type.type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
|
||||
async def factory_wrapper() -> T:
|
||||
maybe_agent_instance = agent_factory()
|
||||
if inspect.isawaitable(maybe_agent_instance):
|
||||
agent_instance = await maybe_agent_instance
|
||||
else:
|
||||
agent_instance = maybe_agent_instance
|
||||
|
||||
if type_func_alias(agent_instance) != expected_class:
|
||||
raise ValueError("Factory registered using the wrong type.")
|
||||
|
||||
return agent_instance
|
||||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
return type
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
with AgentInstantiationContext.populate_context((self, agent_id)):
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
warnings.warn(
|
||||
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
|
||||
stacklevel=2,
|
||||
)
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
|
||||
return agent
|
||||
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if agent_id.type not in self._agent_factories:
|
||||
raise LookupError(f"Agent with name {agent_id.type} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.type]
|
||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
if id.type not in self._agent_factories:
|
||||
raise LookupError(f"Agent with name {id.type} not found.")
|
||||
|
||||
# TODO: check if remote
|
||||
agent_instance = await self._get_agent(id)
|
||||
|
||||
if not isinstance(agent_instance, type):
|
||||
raise TypeError(
|
||||
f"Agent with name {id.type} is not of type {type.__name__}. It is of type {type_func_alias(agent_instance).__name__}"
|
||||
)
|
||||
|
||||
return agent_instance
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
await self._subscription_manager.remove_subscription(id)
|
||||
|
||||
async def get(
|
||||
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
|
||||
) -> AgentId:
|
||||
return await get_impl(
|
||||
id_or_type=id_or_type,
|
||||
key=key,
|
||||
lazy=lazy,
|
||||
instance_getter=self._get_agent,
|
||||
)
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
self._serialization_registry.add_serializer(serializer)
|
||||
@@ -6,7 +6,8 @@ from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util import types
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from ... import AgentId, TopicId
|
||||
from .._agent_id import AgentId
|
||||
from .._topic import TopicId
|
||||
from ._constants import NAMESPACE
|
||||
|
||||
logger = logging.getLogger("autogen_core")
|
||||
@@ -1,9 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from types import NoneType, UnionType
|
||||
from typing import Any, Optional, Tuple, Type, Union, get_args, get_origin
|
||||
|
||||
# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors
|
||||
ChannelArgumentType = Sequence[Tuple[str, Any]]
|
||||
from typing import Any, Optional, Type, Union, get_args, get_origin
|
||||
|
||||
|
||||
def is_union(t: object) -> bool:
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
The :mod:`autogen_core.application` module provides implementations of core components that are used to compose an application
|
||||
"""
|
||||
|
||||
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
|
||||
from ._worker_runtime import WorkerAgentRuntime
|
||||
from ._worker_runtime_host import WorkerAgentRuntimeHost
|
||||
|
||||
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "WorkerAgentRuntimeHost"]
|
||||
|
||||
@@ -1,689 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from autogen_core._serialization import MessageSerializer, SerializationRegistry
|
||||
from .._single_threaded_agent_runtime import SingleThreadedAgentRuntime as SingleThreadedAgentRuntimeAlias
|
||||
|
||||
from .. import (
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentMetadata,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
MessageHandlerContext,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
|
||||
@deprecated(
|
||||
"autogen_core.application.SingleThreadedAgentRuntime moved to autogen_core.SingleThreadedAgentRuntime. This alias will be removed in 0.4.0."
|
||||
)
|
||||
from ..base.intervention import DropMessage, InterventionHandler
|
||||
from ..exceptions import MessageDroppedException
|
||||
from ._helpers import SubscriptionManager, get_impl
|
||||
from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
|
||||
|
||||
logger = logging.getLogger("autogen_core")
|
||||
event_logger = logging.getLogger("autogen_core.events")
|
||||
|
||||
# We use a type parameter in some functions which shadows the built-in `type` function.
|
||||
# This is a workaround to avoid shadowing the built-in `type` function.
|
||||
type_func_alias = type
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PublishMessageEnvelope:
|
||||
"""A message envelope for publishing messages to all agents that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
cancellation_token: CancellationToken
|
||||
sender: AgentId | None
|
||||
topic_id: TopicId
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
message_id: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SendMessageEnvelope:
|
||||
"""A message envelope for sending a message to a specific agent that can handle
|
||||
the message of the type T."""
|
||||
|
||||
message: Any
|
||||
sender: AgentId | None
|
||||
recipient: AgentId
|
||||
future: Future[Any]
|
||||
cancellation_token: CancellationToken
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ResponseMessageEnvelope:
|
||||
"""A message envelope for sending a response to a message."""
|
||||
|
||||
message: Any
|
||||
future: Future[Any]
|
||||
sender: AgentId
|
||||
recipient: AgentId | None
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self) -> None:
|
||||
self._count: int = 0
|
||||
self.threadLock = threading.Lock()
|
||||
|
||||
def increment(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count += 1
|
||||
self.threadLock.release()
|
||||
|
||||
def get(self) -> int:
|
||||
return self._count
|
||||
|
||||
def decrement(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count -= 1
|
||||
self.threadLock.release()
|
||||
|
||||
|
||||
class RunContext:
|
||||
class RunState(Enum):
|
||||
RUNNING = 0
|
||||
CANCELLED = 1
|
||||
UNTIL_IDLE = 2
|
||||
|
||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
self._run_state = RunContext.RunState.RUNNING
|
||||
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
|
||||
self._run_task = asyncio.create_task(self._run())
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._end_condition():
|
||||
return
|
||||
|
||||
await self._runtime.process_next()
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.CANCELLED
|
||||
self._end_condition = self._stop_when_cancelled
|
||||
await self._run_task
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||
self._end_condition = self._stop_when_idle
|
||||
await self._run_task
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
async with self._lock:
|
||||
self._end_condition = condition
|
||||
await self._run_task
|
||||
|
||||
def _stop_when_cancelled(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.CANCELLED
|
||||
|
||||
def _stop_when_idle(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
|
||||
|
||||
|
||||
def _warn_if_none(value: Any, handler_name: str) -> None:
|
||||
"""
|
||||
Utility function to check if the intervention handler returned None and issue a warning.
|
||||
|
||||
Args:
|
||||
value: The return value to check
|
||||
handler_name: Name of the intervention handler method for the warning message
|
||||
"""
|
||||
if value is None:
|
||||
warnings.warn(
|
||||
f"Intervention handler {handler_name} returned None. This might be unintentional. "
|
||||
"Consider returning the original message or DropMessage explicitly.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
intervention_handlers: List[InterventionHandler] | None = None,
|
||||
tracer_provider: TracerProvider | None = None,
|
||||
) -> None:
|
||||
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._agent_factories: Dict[
|
||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||
] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._intervention_handlers = intervention_handlers
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._run_context: RunContext | None = None
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
self,
|
||||
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
|
||||
return self._message_queue
|
||||
|
||||
@property
|
||||
def outstanding_tasks(self) -> int:
|
||||
return self._outstanding_tasks.get()
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
return set(self._agent_factories.keys())
|
||||
|
||||
# Returns the response of the message
|
||||
async def send_message(
|
||||
self,
|
||||
message: Any,
|
||||
recipient: AgentId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> Any:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
with self._tracer_helper.trace_block(
|
||||
"create",
|
||||
recipient,
|
||||
parent=None,
|
||||
extraAttributes={"message_type": type(message).__name__},
|
||||
):
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
if recipient.type not in self._known_agent_names:
|
||||
future.set_exception(Exception("Recipient not found"))
|
||||
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
|
||||
|
||||
self._message_queue.append(
|
||||
SendMessageEnvelope(
|
||||
message=message,
|
||||
recipient=recipient,
|
||||
future=future,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
)
|
||||
)
|
||||
|
||||
cancellation_token.link_future(future)
|
||||
|
||||
return await future
|
||||
|
||||
async def publish_message(
|
||||
self,
|
||||
message: Any,
|
||||
topic_id: TopicId,
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
with self._tracer_helper.trace_block(
|
||||
"create",
|
||||
topic_id,
|
||||
parent=None,
|
||||
extraAttributes={"message_type": type(message).__name__},
|
||||
):
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}")
|
||||
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=None,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
|
||||
self._message_queue.append(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
topic_id=topic_id,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
state: Dict[str, Dict[str, Any]] = {}
|
||||
for agent_id in self._instantiated_agents:
|
||||
state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state())
|
||||
return state
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
for agent_id_str in state:
|
||||
agent_id = AgentId.from_str(agent_id_str)
|
||||
if agent_id.type in self._known_agent_names:
|
||||
await (await self._get_agent(agent_id)).load_state(state[str(agent_id)])
|
||||
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
|
||||
recipient = message_envelope.recipient
|
||||
# todo: check if recipient is in the known namespaces
|
||||
# assert recipient in self._agents
|
||||
|
||||
try:
|
||||
# TODO use id
|
||||
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
recipient_agent = await self._get_agent(recipient)
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
# Will be fixed when send API removed
|
||||
message_id="NOT_DEFINED_TODO_FIX",
|
||||
)
|
||||
with MessageHandlerContext.populate_context(recipient_agent.id):
|
||||
response = await recipient_agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
except CancelledError as e:
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_exception(e)
|
||||
self._outstanding_tasks.decrement()
|
||||
return
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
self._outstanding_tasks.decrement()
|
||||
return
|
||||
|
||||
self._message_queue.append(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
sender=message_envelope.recipient,
|
||||
recipient=message_envelope.sender,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
)
|
||||
)
|
||||
self._outstanding_tasks.decrement()
|
||||
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
|
||||
try:
|
||||
responses: List[Awaitable[Any]] = []
|
||||
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
|
||||
for agent_id in recipients:
|
||||
# Avoid sending the message back to the sender
|
||||
if message_envelope.sender is not None and agent_id == message_envelope.sender:
|
||||
continue
|
||||
|
||||
sender_agent = (
|
||||
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
|
||||
)
|
||||
sender_name = str(sender_agent.id) if sender_agent is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=agent,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=message_envelope.topic_id,
|
||||
is_rpc=False,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
message_id=message_envelope.message_id,
|
||||
)
|
||||
agent = await self._get_agent(agent_id)
|
||||
|
||||
async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
|
||||
with self._tracer_helper.trace_block("process", agent.id, parent=None):
|
||||
with MessageHandlerContext.populate_context(agent.id):
|
||||
return await agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
|
||||
future = _on_message(agent, message_context)
|
||||
responses.append(future)
|
||||
|
||||
await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
# Ignore cancelled errors from logs
|
||||
if isinstance(e, CancelledError):
|
||||
return
|
||||
logger.error("Error processing publish message", exc_info=True)
|
||||
finally:
|
||||
self._outstanding_tasks.decrement()
|
||||
# TODO if responses are given for a publish
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata):
|
||||
content = (
|
||||
message_envelope.message.__dict__
|
||||
if hasattr(message_envelope.message, "__dict__")
|
||||
else message_envelope.message
|
||||
)
|
||||
logger.info(
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=message_envelope.recipient,
|
||||
# kind=MessageKind.RESPOND,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
self._outstanding_tasks.decrement()
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
async def process_next(self) -> None:
|
||||
"""Process the next message in the queue."""
|
||||
|
||||
if len(self._message_queue) == 0:
|
||||
# Yield control to the event loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
return
|
||||
message_envelope = self._message_queue.pop(0)
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
with self._tracer_helper.trace_block(
|
||||
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
|
||||
):
|
||||
try:
|
||||
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
|
||||
_warn_if_none(temp_message, "on_send")
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_send(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
with self._tracer_helper.trace_block(
|
||||
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
|
||||
):
|
||||
try:
|
||||
temp_message = await handler.on_publish(message, sender=sender)
|
||||
_warn_if_none(temp_message, "on_publish")
|
||||
except BaseException as e:
|
||||
# TODO: we should raise the intervention exception to the publisher.
|
||||
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
# TODO log message dropped
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_publish(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
try:
|
||||
temp_message = await handler.on_response(message, sender=sender, recipient=recipient)
|
||||
_warn_if_none(temp_message, "on_response")
|
||||
except BaseException as e:
|
||||
# TODO: should we raise the exception to sender of the response instead?
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_response(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the runtime message processing loop."""
|
||||
if self._run_context is not None:
|
||||
raise RuntimeError("Runtime is already started")
|
||||
self._run_context = RunContext(self)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the runtime message processing loop."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop()
|
||||
self._run_context = None
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
"""Stop the runtime message processing loop when there is
|
||||
no outstanding message being processed or queued."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when_idle()
|
||||
self._run_context = None
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
"""Stop the runtime message processing loop when the condition is met."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when(condition)
|
||||
self._run_context = None
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return (await self._get_agent(agent)).metadata
|
||||
|
||||
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
|
||||
return await (await self._get_agent(agent)).save_state()
|
||||
|
||||
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
|
||||
await (await self._get_agent(agent)).load_state(state)
|
||||
|
||||
@deprecated(
|
||||
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
|
||||
)
|
||||
async def register(
|
||||
self,
|
||||
type: str,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
|
||||
| list[Subscription]
|
||||
| None = None,
|
||||
) -> AgentType:
|
||||
if type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
|
||||
if subscriptions is not None:
|
||||
if callable(subscriptions):
|
||||
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
|
||||
subscriptions_list_result = subscriptions()
|
||||
if inspect.isawaitable(subscriptions_list_result):
|
||||
subscriptions_list = await subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions_list_result
|
||||
else:
|
||||
subscriptions_list = subscriptions
|
||||
|
||||
for subscription in subscriptions_list:
|
||||
await self.add_subscription(subscription)
|
||||
|
||||
self._agent_factories[type] = agent_factory
|
||||
return AgentType(type)
|
||||
|
||||
async def register_factory(
|
||||
self,
|
||||
*,
|
||||
type: AgentType,
|
||||
agent_factory: Callable[[], T | Awaitable[T]],
|
||||
expected_class: type[T],
|
||||
) -> AgentType:
|
||||
if type.type in self._agent_factories:
|
||||
raise ValueError(f"Agent with type {type} already exists.")
|
||||
|
||||
async def factory_wrapper() -> T:
|
||||
maybe_agent_instance = agent_factory()
|
||||
if inspect.isawaitable(maybe_agent_instance):
|
||||
agent_instance = await maybe_agent_instance
|
||||
else:
|
||||
agent_instance = maybe_agent_instance
|
||||
|
||||
if type_func_alias(agent_instance) != expected_class:
|
||||
raise ValueError("Factory registered using the wrong type.")
|
||||
|
||||
return agent_instance
|
||||
|
||||
self._agent_factories[type.type] = factory_wrapper
|
||||
|
||||
return type
|
||||
|
||||
async def _invoke_agent_factory(
|
||||
self,
|
||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
with AgentInstantiationContext.populate_context((self, agent_id)):
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
warnings.warn(
|
||||
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
|
||||
stacklevel=2,
|
||||
)
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
|
||||
return agent
|
||||
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
if agent_id in self._instantiated_agents:
|
||||
return self._instantiated_agents[agent_id]
|
||||
|
||||
if agent_id.type not in self._agent_factories:
|
||||
raise LookupError(f"Agent with name {agent_id.type} not found.")
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.type]
|
||||
agent = await self._invoke_agent_factory(agent_factory, agent_id)
|
||||
self._instantiated_agents[agent_id] = agent
|
||||
return agent
|
||||
|
||||
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
|
||||
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
|
||||
if id.type not in self._agent_factories:
|
||||
raise LookupError(f"Agent with name {id.type} not found.")
|
||||
|
||||
# TODO: check if remote
|
||||
agent_instance = await self._get_agent(id)
|
||||
|
||||
if not isinstance(agent_instance, type):
|
||||
raise TypeError(
|
||||
f"Agent with name {id.type} is not of type {type.__name__}. It is of type {type_func_alias(agent_instance).__name__}"
|
||||
)
|
||||
|
||||
return agent_instance
|
||||
|
||||
async def add_subscription(self, subscription: Subscription) -> None:
|
||||
await self._subscription_manager.add_subscription(subscription)
|
||||
|
||||
async def remove_subscription(self, id: str) -> None:
|
||||
await self._subscription_manager.remove_subscription(id)
|
||||
|
||||
async def get(
|
||||
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
|
||||
) -> AgentId:
|
||||
return await get_impl(
|
||||
id_or_type=id_or_type,
|
||||
key=key,
|
||||
lazy=lazy,
|
||||
instance_getter=self._get_agent,
|
||||
)
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
self._serialization_registry.add_serializer(serializer)
|
||||
class SingleThreadedAgentRuntime(SingleThreadedAgentRuntimeAlias):
|
||||
pass
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
ROOT_LOGGER_NAME = "autogen_core"
|
||||
"""str: Logger name used for structured event logging"""
|
||||
"""Deprecated alias. Use autogen_core.ROOT_LOGGER_NAME"""
|
||||
|
||||
EVENT_LOGGER_NAME = "autogen_core.events"
|
||||
"""str: Logger name used for structured event logging"""
|
||||
"""Deprecated alias. Use autogen_core.EVENT_LOGGER_NAME"""
|
||||
|
||||
|
||||
TRACE_LOGGER_NAME = "autogen_core.trace"
|
||||
"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
|
||||
|
||||
__all__ = [
|
||||
"ROOT_LOGGER_NAME",
|
||||
"EVENT_LOGGER_NAME",
|
||||
"TRACE_LOGGER_NAME",
|
||||
]
|
||||
"""Deprecated alias. Use autogen_core.TRACE_LOGGER_NAME"""
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
The :mod:`autogen_core.worker.protos` module provides Google Protobuf classes for agent-worker communication
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
||||
@@ -1,75 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: agent_worker.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
import cloudevent_pb2 as cloudevent__pb2
|
||||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_worker_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\036Microsoft.AutoGen.Abstractions'
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._options = None
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._options = None
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_EVENT_METADATAENTRY']._options = None
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_TOPICID']._serialized_start=75
|
||||
_globals['_TOPICID']._serialized_end=114
|
||||
_globals['_AGENTID']._serialized_start=116
|
||||
_globals['_AGENTID']._serialized_end=152
|
||||
_globals['_PAYLOAD']._serialized_start=154
|
||||
_globals['_PAYLOAD']._serialized_end=223
|
||||
_globals['_RPCREQUEST']._serialized_start=226
|
||||
_globals['_RPCREQUEST']._serialized_end=491
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=433
|
||||
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=480
|
||||
_globals['_RPCRESPONSE']._serialized_start=494
|
||||
_globals['_RPCRESPONSE']._serialized_end=678
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=433
|
||||
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=480
|
||||
_globals['_EVENT']._serialized_start=681
|
||||
_globals['_EVENT']._serialized_end=909
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_start=433
|
||||
_globals['_EVENT_METADATAENTRY']._serialized_end=480
|
||||
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911
|
||||
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971
|
||||
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973
|
||||
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_start=1069
|
||||
_globals['_TYPESUBSCRIPTION']._serialized_end=1127
|
||||
_globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=1129
|
||||
_globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=1200
|
||||
_globals['_SUBSCRIPTION']._serialized_start=1203
|
||||
_globals['_SUBSCRIPTION']._serialized_end=1353
|
||||
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1355
|
||||
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1443
|
||||
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1445
|
||||
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1537
|
||||
_globals['_AGENTSTATE']._serialized_start=1540
|
||||
_globals['_AGENTSTATE']._serialized_end=1697
|
||||
_globals['_GETSTATERESPONSE']._serialized_start=1699
|
||||
_globals['_GETSTATERESPONSE']._serialized_end=1805
|
||||
_globals['_SAVESTATERESPONSE']._serialized_start=1807
|
||||
_globals['_SAVESTATERESPONSE']._serialized_end=1873
|
||||
_globals['_MESSAGE']._serialized_start=1876
|
||||
_globals['_MESSAGE']._serialized_end=2298
|
||||
_globals['_AGENTRPC']._serialized_start=2301
|
||||
_globals['_AGENTRPC']._serialized_end=2479
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,132 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import agent_worker_pb2 as agent__worker__pb2
|
||||
|
||||
|
||||
class AgentRpcStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.OpenChannel = channel.stream_stream(
|
||||
'/agents.AgentRpc/OpenChannel',
|
||||
request_serializer=agent__worker__pb2.Message.SerializeToString,
|
||||
response_deserializer=agent__worker__pb2.Message.FromString,
|
||||
)
|
||||
self.GetState = channel.unary_unary(
|
||||
'/agents.AgentRpc/GetState',
|
||||
request_serializer=agent__worker__pb2.AgentId.SerializeToString,
|
||||
response_deserializer=agent__worker__pb2.GetStateResponse.FromString,
|
||||
)
|
||||
self.SaveState = channel.unary_unary(
|
||||
'/agents.AgentRpc/SaveState',
|
||||
request_serializer=agent__worker__pb2.AgentState.SerializeToString,
|
||||
response_deserializer=agent__worker__pb2.SaveStateResponse.FromString,
|
||||
)
|
||||
|
||||
|
||||
class AgentRpcServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def OpenChannel(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetState(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SaveState(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_AgentRpcServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'OpenChannel': grpc.stream_stream_rpc_method_handler(
|
||||
servicer.OpenChannel,
|
||||
request_deserializer=agent__worker__pb2.Message.FromString,
|
||||
response_serializer=agent__worker__pb2.Message.SerializeToString,
|
||||
),
|
||||
'GetState': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetState,
|
||||
request_deserializer=agent__worker__pb2.AgentId.FromString,
|
||||
response_serializer=agent__worker__pb2.GetStateResponse.SerializeToString,
|
||||
),
|
||||
'SaveState': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SaveState,
|
||||
request_deserializer=agent__worker__pb2.AgentState.FromString,
|
||||
response_serializer=agent__worker__pb2.SaveStateResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'agents.AgentRpc', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AgentRpc(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def OpenChannel(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_stream(request_iterator, target, '/agents.AgentRpc/OpenChannel',
|
||||
agent__worker__pb2.Message.SerializeToString,
|
||||
agent__worker__pb2.Message.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def GetState(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/GetState',
|
||||
agent__worker__pb2.AgentId.SerializeToString,
|
||||
agent__worker__pb2.GetStateResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def SaveState(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/SaveState',
|
||||
agent__worker__pb2.AgentState.SerializeToString,
|
||||
agent__worker__pb2.SaveStateResponse.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
@@ -1,39 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: cloudevent.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63loudevent.proto\x12\ncloudevent\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xa4\x05\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12:\n\nattributes\x18\x05 \x03(\x0b\x32&.cloudevent.CloudEvent.AttributesEntry\x12\x36\n\x08metadata\x18\x06 \x03(\x0b\x32$.cloudevent.CloudEvent.MetadataEntry\x12\x17\n\x0f\x64\x61tacontenttype\x18\x07 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x08 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\t \x01(\tH\x00\x12*\n\nproto_data\x18\n \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1a\x62\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12>\n\x05value\x18\x02 \x01(\x0b\x32/.cloudevent.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cloudevent_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'\252\002\036Microsoft.AutoGen.Abstractions'
|
||||
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._options = None
|
||||
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_options = b'8\001'
|
||||
_globals['_CLOUDEVENT_METADATAENTRY']._options = None
|
||||
_globals['_CLOUDEVENT_METADATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_CLOUDEVENT']._serialized_start=93
|
||||
_globals['_CLOUDEVENT']._serialized_end=769
|
||||
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_start=400
|
||||
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_end=498
|
||||
_globals['_CLOUDEVENT_METADATAENTRY']._serialized_start=500
|
||||
_globals['_CLOUDEVENT_METADATAENTRY']._serialized_end=547
|
||||
_globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_start=550
|
||||
_globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_end=761
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime
|
||||
from autogen_test_utils import NoopAgent
|
||||
from pytest_mock import MockerFixture
|
||||
from test_utils import NoopAgent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -8,9 +8,9 @@ from autogen_core import (
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2,8 +2,14 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from autogen_core import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId, MessageContext
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import (
|
||||
ClosureAgent,
|
||||
ClosureContext,
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
MessageContext,
|
||||
SingleThreadedAgentRuntime,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import pytest
|
||||
from autogen_core import AgentId
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import AgentId, SingleThreadedAgentRuntime
|
||||
from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage
|
||||
from autogen_core.exceptions import MessageDroppedException
|
||||
from test_utils import LoopbackAgent, MessageType
|
||||
from autogen_test_utils import LoopbackAgent, MessageType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -3,9 +3,18 @@ from dataclasses import dataclass
|
||||
from typing import Callable, cast
|
||||
|
||||
import pytest
|
||||
from autogen_core import AgentId, MessageContext, RoutedAgent, TopicId, TypeSubscription, event, message_handler, rpc
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from test_utils import LoopbackAgent
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
TopicId,
|
||||
TypeSubscription,
|
||||
event,
|
||||
message_handler,
|
||||
rpc,
|
||||
)
|
||||
from autogen_test_utils import LoopbackAgent
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -6,14 +6,13 @@ from autogen_core import (
|
||||
AgentInstantiationContext,
|
||||
AgentType,
|
||||
DefaultTopicId,
|
||||
SingleThreadedAgentRuntime,
|
||||
TopicId,
|
||||
TypeSubscription,
|
||||
try_get_known_serializers_for_type,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from test_utils import (
|
||||
from autogen_test_utils import (
|
||||
CascadingAgent,
|
||||
CascadingMessageType,
|
||||
LoopbackAgent,
|
||||
@@ -21,7 +20,8 @@ from test_utils import (
|
||||
MessageType,
|
||||
NoopAgent,
|
||||
)
|
||||
from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
|
||||
from autogen_test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
test_exporter = TestExporter()
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Any, Mapping
|
||||
|
||||
import pytest
|
||||
from autogen_core import AgentId, BaseAgent, MessageContext
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import AgentId, BaseAgent, MessageContext, SingleThreadedAgentRuntime
|
||||
|
||||
|
||||
class StatefulAgent(BaseAgent):
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import pytest
|
||||
from autogen_core import AgentId, DefaultSubscription, DefaultTopicId, TopicId, TypeSubscription
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
DefaultSubscription,
|
||||
DefaultTopicId,
|
||||
SingleThreadedAgentRuntime,
|
||||
TopicId,
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core.exceptions import CantHandleException
|
||||
from test_utils import LoopbackAgent, MessageType
|
||||
from autogen_test_utils import LoopbackAgent, MessageType
|
||||
|
||||
|
||||
def test_type_subscription_match() -> None:
|
||||
|
||||
@@ -3,8 +3,7 @@ import json
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
|
||||
|
||||
import pytest
|
||||
from autogen_core import AgentId, CancellationToken, FunctionCall
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import AgentId, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
|
||||
@@ -41,21 +41,28 @@ video-surfer = [
|
||||
"openai-whisper",
|
||||
]
|
||||
|
||||
grpc = [
|
||||
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/autogen_ext"]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = []
|
||||
dev-dependencies = [
|
||||
"autogen_test_utils"
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
include = ["src/**", "tests/*.py"]
|
||||
exclude = ["src/autogen_ext/agents/web_surfer/*.js"]
|
||||
exclude = ["src/autogen_ext/agents/web_surfer/*.js", "src/autogen_ext/runtimes/grpc/protos", "tests/protos"]
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["src", "tests"]
|
||||
exclude = ["src/autogen_ext/runtimes/grpc/protos", "tests/protos"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
@@ -66,6 +73,7 @@ include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
test = "pytest -n auto"
|
||||
mypy = "mypy --config-file ../../pyproject.toml --exclude src/autogen_ext/runtimes/grpc/protos --exclude tests/protos src tests"
|
||||
|
||||
[tool.mypy]
|
||||
[[tool.mypy.overrides]]
|
||||
|
||||
@@ -22,12 +22,12 @@ from typing import (
|
||||
|
||||
import tiktoken
|
||||
from autogen_core import (
|
||||
EVENT_LOGGER_NAME,
|
||||
TRACE_LOGGER_NAME,
|
||||
CancellationToken,
|
||||
FunctionCall,
|
||||
Image,
|
||||
)
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from autogen_core.application.logging.events import LLMCallEvent
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
@@ -42,6 +42,7 @@ from autogen_core.components.models import (
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import Tool, ToolSchema
|
||||
from autogen_core.logging import LLMCallEvent
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
from ._worker_runtime import GrpcWorkerAgentRuntime
|
||||
from ._worker_runtime_host import GrpcWorkerAgentRuntimeHost
|
||||
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
|
||||
|
||||
try:
|
||||
import grpc # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To use the GRPC runtime the grpc extra must be installed. Run `pip install autogen-ext[grpc]`"
|
||||
) from e
|
||||
|
||||
__all__ = [
|
||||
"GrpcWorkerAgentRuntime",
|
||||
"GrpcWorkerAgentRuntimeHost",
|
||||
"GrpcWorkerAgentRuntimeHostServicer",
|
||||
]
|
||||
@@ -0,0 +1,4 @@
|
||||
from typing import Any, Sequence, Tuple
|
||||
|
||||
# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors
|
||||
ChannelArgumentType = Sequence[Tuple[str, Any]]
|
||||
@@ -28,13 +28,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from google.protobuf import any_pb2
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from autogen_core.application.protos import cloudevent_pb2
|
||||
|
||||
from .. import (
|
||||
from autogen_core import (
|
||||
JSON_DATA_CONTENT_TYPE,
|
||||
PROTOBUF_DATA_CONTENT_TYPE,
|
||||
Agent,
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
@@ -44,24 +40,26 @@ from .. import (
|
||||
CancellationToken,
|
||||
MessageContext,
|
||||
MessageHandlerContext,
|
||||
MessageSerializer,
|
||||
Subscription,
|
||||
SubscriptionInstantiationContext,
|
||||
TopicId,
|
||||
TypePrefixSubscription,
|
||||
TypeSubscription,
|
||||
)
|
||||
from .._serialization import (
|
||||
JSON_DATA_CONTENT_TYPE,
|
||||
PROTOBUF_DATA_CONTENT_TYPE,
|
||||
MessageSerializer,
|
||||
from autogen_core._runtime_impl_helpers import SubscriptionManager, get_impl
|
||||
from autogen_core._serialization import (
|
||||
SerializationRegistry,
|
||||
)
|
||||
from .._type_helpers import ChannelArgumentType
|
||||
from .._type_prefix_subscription import TypePrefixSubscription
|
||||
from .._type_subscription import TypeSubscription
|
||||
from autogen_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
|
||||
from google.protobuf import any_pb2
|
||||
from opentelemetry.trace import TracerProvider
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from . import _constants
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._helpers import SubscriptionManager, get_impl
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
||||
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
|
||||
from ._type_helpers import ChannelArgumentType
|
||||
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
|
||||
|
||||
try:
|
||||
import grpc.aio
|
||||
@@ -181,7 +179,7 @@ class HostConnection:
|
||||
return await self._recv_queue.get()
|
||||
|
||||
|
||||
class WorkerAgentRuntime(AgentRuntime):
|
||||
class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
@@ -3,9 +3,9 @@ import logging
|
||||
import signal
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from .._type_helpers import ChannelArgumentType
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
|
||||
from ._type_helpers import ChannelArgumentType
|
||||
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
|
||||
|
||||
try:
|
||||
import grpc
|
||||
@@ -16,10 +16,10 @@ from .protos import agent_worker_pb2_grpc
|
||||
logger = logging.getLogger("autogen_core")
|
||||
|
||||
|
||||
class WorkerAgentRuntimeHost:
|
||||
class GrpcWorkerAgentRuntimeHost:
|
||||
def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None:
|
||||
self._server = grpc.aio.server(options=extra_grpc_config)
|
||||
self._servicer = WorkerAgentRuntimeHostServicer()
|
||||
self._servicer = GrpcWorkerAgentRuntimeHostServicer()
|
||||
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
|
||||
self._server.add_insecure_port(address)
|
||||
self._address = address
|
||||
@@ -4,10 +4,10 @@ from _collections_abc import AsyncIterator, Iterator
|
||||
from asyncio import Future, Task
|
||||
from typing import Any, Dict, Set, cast
|
||||
|
||||
from .. import Subscription, TopicId, TypeSubscription
|
||||
from .._type_prefix_subscription import TypePrefixSubscription
|
||||
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
|
||||
from autogen_core._runtime_impl_helpers import SubscriptionManager
|
||||
|
||||
from ._constants import GRPC_IMPORT_ERROR_STR
|
||||
from ._helpers import SubscriptionManager
|
||||
|
||||
try:
|
||||
import grpc
|
||||
@@ -20,7 +20,7 @@ logger = logging.getLogger("autogen_core")
|
||||
event_logger = logging.getLogger("autogen_core.events")
|
||||
|
||||
|
||||
class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
|
||||
"""A gRPC servicer that hosts message delivery service for agents."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
The :mod:`autogen_ext.runtimes.grpc.protos` module provides Google Protobuf classes for agent-worker communication
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
||||
@@ -0,0 +1,78 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: agent_worker.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
import cloudevent_pb2 as cloudevent__pb2
|
||||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3'
|
||||
)
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "agent_worker_pb2", _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals["DESCRIPTOR"]._options = None
|
||||
_globals["DESCRIPTOR"]._serialized_options = b"\252\002\036Microsoft.AutoGen.Abstractions"
|
||||
_globals["_RPCREQUEST_METADATAENTRY"]._options = None
|
||||
_globals["_RPCREQUEST_METADATAENTRY"]._serialized_options = b"8\001"
|
||||
_globals["_RPCRESPONSE_METADATAENTRY"]._options = None
|
||||
_globals["_RPCRESPONSE_METADATAENTRY"]._serialized_options = b"8\001"
|
||||
_globals["_EVENT_METADATAENTRY"]._options = None
|
||||
_globals["_EVENT_METADATAENTRY"]._serialized_options = b"8\001"
|
||||
_globals["_TOPICID"]._serialized_start = 75
|
||||
_globals["_TOPICID"]._serialized_end = 114
|
||||
_globals["_AGENTID"]._serialized_start = 116
|
||||
_globals["_AGENTID"]._serialized_end = 152
|
||||
_globals["_PAYLOAD"]._serialized_start = 154
|
||||
_globals["_PAYLOAD"]._serialized_end = 223
|
||||
_globals["_RPCREQUEST"]._serialized_start = 226
|
||||
_globals["_RPCREQUEST"]._serialized_end = 491
|
||||
_globals["_RPCREQUEST_METADATAENTRY"]._serialized_start = 433
|
||||
_globals["_RPCREQUEST_METADATAENTRY"]._serialized_end = 480
|
||||
_globals["_RPCRESPONSE"]._serialized_start = 494
|
||||
_globals["_RPCRESPONSE"]._serialized_end = 678
|
||||
_globals["_RPCRESPONSE_METADATAENTRY"]._serialized_start = 433
|
||||
_globals["_RPCRESPONSE_METADATAENTRY"]._serialized_end = 480
|
||||
_globals["_EVENT"]._serialized_start = 681
|
||||
_globals["_EVENT"]._serialized_end = 909
|
||||
_globals["_EVENT_METADATAENTRY"]._serialized_start = 433
|
||||
_globals["_EVENT_METADATAENTRY"]._serialized_end = 480
|
||||
_globals["_REGISTERAGENTTYPEREQUEST"]._serialized_start = 911
|
||||
_globals["_REGISTERAGENTTYPEREQUEST"]._serialized_end = 971
|
||||
_globals["_REGISTERAGENTTYPERESPONSE"]._serialized_start = 973
|
||||
_globals["_REGISTERAGENTTYPERESPONSE"]._serialized_end = 1067
|
||||
_globals["_TYPESUBSCRIPTION"]._serialized_start = 1069
|
||||
_globals["_TYPESUBSCRIPTION"]._serialized_end = 1127
|
||||
_globals["_TYPEPREFIXSUBSCRIPTION"]._serialized_start = 1129
|
||||
_globals["_TYPEPREFIXSUBSCRIPTION"]._serialized_end = 1200
|
||||
_globals["_SUBSCRIPTION"]._serialized_start = 1203
|
||||
_globals["_SUBSCRIPTION"]._serialized_end = 1353
|
||||
_globals["_ADDSUBSCRIPTIONREQUEST"]._serialized_start = 1355
|
||||
_globals["_ADDSUBSCRIPTIONREQUEST"]._serialized_end = 1443
|
||||
_globals["_ADDSUBSCRIPTIONRESPONSE"]._serialized_start = 1445
|
||||
_globals["_ADDSUBSCRIPTIONRESPONSE"]._serialized_end = 1537
|
||||
_globals["_AGENTSTATE"]._serialized_start = 1540
|
||||
_globals["_AGENTSTATE"]._serialized_end = 1697
|
||||
_globals["_GETSTATERESPONSE"]._serialized_start = 1699
|
||||
_globals["_GETSTATERESPONSE"]._serialized_end = 1805
|
||||
_globals["_SAVESTATERESPONSE"]._serialized_start = 1807
|
||||
_globals["_SAVESTATERESPONSE"]._serialized_end = 1873
|
||||
_globals["_MESSAGE"]._serialized_start = 1876
|
||||
_globals["_MESSAGE"]._serialized_end = 2298
|
||||
_globals["_AGENTRPC"]._serialized_start = 2301
|
||||
_globals["_AGENTRPC"]._serialized_end = 2479
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -4,13 +4,14 @@ isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import cloudevent_pb2
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
import cloudevent_pb2
|
||||
import google.protobuf.any_pb2
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@@ -67,7 +68,12 @@ class Payload(google.protobuf.message.Message):
|
||||
data_content_type: builtins.str = ...,
|
||||
data: builtins.bytes = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"
|
||||
],
|
||||
) -> None: ...
|
||||
|
||||
global___Payload = Payload
|
||||
|
||||
@@ -117,8 +123,31 @@ class RpcRequest(google.protobuf.message.Message):
|
||||
payload: global___Payload | None = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
|
||||
def HasField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"
|
||||
],
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"_source",
|
||||
b"_source",
|
||||
"metadata",
|
||||
b"metadata",
|
||||
"method",
|
||||
b"method",
|
||||
"payload",
|
||||
b"payload",
|
||||
"request_id",
|
||||
b"request_id",
|
||||
"source",
|
||||
b"source",
|
||||
"target",
|
||||
b"target",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
|
||||
|
||||
global___RpcRequest = RpcRequest
|
||||
@@ -162,7 +191,12 @@ class RpcResponse(google.protobuf.message.Message):
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"
|
||||
],
|
||||
) -> None: ...
|
||||
|
||||
global___RpcResponse = RpcResponse
|
||||
|
||||
@@ -208,8 +242,26 @@ class Event(google.protobuf.message.Message):
|
||||
payload: global___Payload | None = ...,
|
||||
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
|
||||
def HasField(
|
||||
self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"_source",
|
||||
b"_source",
|
||||
"metadata",
|
||||
b"metadata",
|
||||
"payload",
|
||||
b"payload",
|
||||
"source",
|
||||
b"source",
|
||||
"topic_source",
|
||||
b"topic_source",
|
||||
"topic_type",
|
||||
b"topic_type",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
|
||||
|
||||
global___Event = Event
|
||||
@@ -250,7 +302,12 @@ class RegisterAgentTypeResponse(google.protobuf.message.Message):
|
||||
error: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
|
||||
|
||||
global___RegisterAgentTypeResponse = RegisterAgentTypeResponse
|
||||
@@ -269,7 +326,9 @@ class TypeSubscription(google.protobuf.message.Message):
|
||||
topic_type: builtins.str = ...,
|
||||
agent_type: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]) -> None: ...
|
||||
def ClearField(
|
||||
self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]
|
||||
) -> None: ...
|
||||
|
||||
global___TypeSubscription = TypeSubscription
|
||||
|
||||
@@ -287,7 +346,9 @@ class TypePrefixSubscription(google.protobuf.message.Message):
|
||||
topic_type_prefix: builtins.str = ...,
|
||||
agent_type: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type_prefix", b"topic_type_prefix"]) -> None: ...
|
||||
def ClearField(
|
||||
self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type_prefix", b"topic_type_prefix"]
|
||||
) -> None: ...
|
||||
|
||||
global___TypePrefixSubscription = TypePrefixSubscription
|
||||
|
||||
@@ -307,9 +368,31 @@ class Subscription(google.protobuf.message.Message):
|
||||
typeSubscription: global___TypeSubscription | None = ...,
|
||||
typePrefixSubscription: global___TypePrefixSubscription | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...
|
||||
def HasField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"subscription",
|
||||
b"subscription",
|
||||
"typePrefixSubscription",
|
||||
b"typePrefixSubscription",
|
||||
"typeSubscription",
|
||||
b"typeSubscription",
|
||||
],
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"subscription",
|
||||
b"subscription",
|
||||
"typePrefixSubscription",
|
||||
b"typePrefixSubscription",
|
||||
"typeSubscription",
|
||||
b"typeSubscription",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(
|
||||
self, oneof_group: typing.Literal["subscription", b"subscription"]
|
||||
) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...
|
||||
|
||||
global___Subscription = Subscription
|
||||
|
||||
@@ -329,7 +412,9 @@ class AddSubscriptionRequest(google.protobuf.message.Message):
|
||||
subscription: global___Subscription | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["request_id", b"request_id", "subscription", b"subscription"]) -> None: ...
|
||||
def ClearField(
|
||||
self, field_name: typing.Literal["request_id", b"request_id", "subscription", b"subscription"]
|
||||
) -> None: ...
|
||||
|
||||
global___AddSubscriptionRequest = AddSubscriptionRequest
|
||||
|
||||
@@ -351,7 +436,12 @@ class AddSubscriptionResponse(google.protobuf.message.Message):
|
||||
error: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
|
||||
|
||||
global___AddSubscriptionResponse = AddSubscriptionResponse
|
||||
@@ -381,9 +471,41 @@ class AgentState(google.protobuf.message.Message):
|
||||
text_data: builtins.str = ...,
|
||||
proto_data: google.protobuf.any_pb2.Any | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "eTag", b"eTag", "proto_data", b"proto_data", "text_data", b"text_data"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
|
||||
def HasField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"agent_id",
|
||||
b"agent_id",
|
||||
"binary_data",
|
||||
b"binary_data",
|
||||
"data",
|
||||
b"data",
|
||||
"proto_data",
|
||||
b"proto_data",
|
||||
"text_data",
|
||||
b"text_data",
|
||||
],
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"agent_id",
|
||||
b"agent_id",
|
||||
"binary_data",
|
||||
b"binary_data",
|
||||
"data",
|
||||
b"data",
|
||||
"eTag",
|
||||
b"eTag",
|
||||
"proto_data",
|
||||
b"proto_data",
|
||||
"text_data",
|
||||
b"text_data",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(
|
||||
self, oneof_group: typing.Literal["data", b"data"]
|
||||
) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
|
||||
|
||||
global___AgentState = AgentState
|
||||
|
||||
@@ -405,8 +527,15 @@ class GetStateResponse(google.protobuf.message.Message):
|
||||
success: builtins.bool = ...,
|
||||
error: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_error", b"_error", "agent_state", b"agent_state", "error", b"error"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_error", b"_error", "agent_state", b"agent_state", "error", b"error", "success", b"success"]) -> None: ...
|
||||
def HasField(
|
||||
self, field_name: typing.Literal["_error", b"_error", "agent_state", b"agent_state", "error", b"error"]
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"_error", b"_error", "agent_state", b"agent_state", "error", b"error", "success", b"success"
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
|
||||
|
||||
global___GetStateResponse = GetStateResponse
|
||||
@@ -426,7 +555,9 @@ class SaveStateResponse(google.protobuf.message.Message):
|
||||
error: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "success", b"success"]) -> None: ...
|
||||
def ClearField(
|
||||
self, field_name: typing.Literal["_error", b"_error", "error", b"error", "success", b"success"]
|
||||
) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
|
||||
|
||||
global___SaveStateResponse = SaveStateResponse
|
||||
@@ -467,8 +598,61 @@ class Message(google.protobuf.message.Message):
|
||||
addSubscriptionRequest: global___AddSubscriptionRequest | None = ...,
|
||||
addSubscriptionResponse: global___AddSubscriptionResponse | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ...
|
||||
def HasField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"addSubscriptionRequest",
|
||||
b"addSubscriptionRequest",
|
||||
"addSubscriptionResponse",
|
||||
b"addSubscriptionResponse",
|
||||
"cloudEvent",
|
||||
b"cloudEvent",
|
||||
"message",
|
||||
b"message",
|
||||
"registerAgentTypeRequest",
|
||||
b"registerAgentTypeRequest",
|
||||
"registerAgentTypeResponse",
|
||||
b"registerAgentTypeResponse",
|
||||
"request",
|
||||
b"request",
|
||||
"response",
|
||||
b"response",
|
||||
],
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"addSubscriptionRequest",
|
||||
b"addSubscriptionRequest",
|
||||
"addSubscriptionResponse",
|
||||
b"addSubscriptionResponse",
|
||||
"cloudEvent",
|
||||
b"cloudEvent",
|
||||
"message",
|
||||
b"message",
|
||||
"registerAgentTypeRequest",
|
||||
b"registerAgentTypeRequest",
|
||||
"registerAgentTypeResponse",
|
||||
b"registerAgentTypeResponse",
|
||||
"request",
|
||||
b"request",
|
||||
"response",
|
||||
b"response",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(
|
||||
self, oneof_group: typing.Literal["message", b"message"]
|
||||
) -> (
|
||||
typing.Literal[
|
||||
"request",
|
||||
"response",
|
||||
"cloudEvent",
|
||||
"registerAgentTypeRequest",
|
||||
"registerAgentTypeResponse",
|
||||
"addSubscriptionRequest",
|
||||
"addSubscriptionResponse",
|
||||
]
|
||||
| None
|
||||
): ...
|
||||
|
||||
global___Message = Message
|
||||
@@ -0,0 +1,167 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
|
||||
import agent_worker_pb2 as agent__worker__pb2
|
||||
import grpc
|
||||
|
||||
|
||||
class AgentRpcStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.OpenChannel = channel.stream_stream(
|
||||
"/agents.AgentRpc/OpenChannel",
|
||||
request_serializer=agent__worker__pb2.Message.SerializeToString,
|
||||
response_deserializer=agent__worker__pb2.Message.FromString,
|
||||
)
|
||||
self.GetState = channel.unary_unary(
|
||||
"/agents.AgentRpc/GetState",
|
||||
request_serializer=agent__worker__pb2.AgentId.SerializeToString,
|
||||
response_deserializer=agent__worker__pb2.GetStateResponse.FromString,
|
||||
)
|
||||
self.SaveState = channel.unary_unary(
|
||||
"/agents.AgentRpc/SaveState",
|
||||
request_serializer=agent__worker__pb2.AgentState.SerializeToString,
|
||||
response_deserializer=agent__worker__pb2.SaveStateResponse.FromString,
|
||||
)
|
||||
|
||||
|
||||
class AgentRpcServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def OpenChannel(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Method not implemented!")
|
||||
raise NotImplementedError("Method not implemented!")
|
||||
|
||||
def GetState(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Method not implemented!")
|
||||
raise NotImplementedError("Method not implemented!")
|
||||
|
||||
def SaveState(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details("Method not implemented!")
|
||||
raise NotImplementedError("Method not implemented!")
|
||||
|
||||
|
||||
def add_AgentRpcServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
"OpenChannel": grpc.stream_stream_rpc_method_handler(
|
||||
servicer.OpenChannel,
|
||||
request_deserializer=agent__worker__pb2.Message.FromString,
|
||||
response_serializer=agent__worker__pb2.Message.SerializeToString,
|
||||
),
|
||||
"GetState": grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetState,
|
||||
request_deserializer=agent__worker__pb2.AgentId.FromString,
|
||||
response_serializer=agent__worker__pb2.GetStateResponse.SerializeToString,
|
||||
),
|
||||
"SaveState": grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SaveState,
|
||||
request_deserializer=agent__worker__pb2.AgentState.FromString,
|
||||
response_serializer=agent__worker__pb2.SaveStateResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler("agents.AgentRpc", rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class AgentRpc(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def OpenChannel(
|
||||
request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None,
|
||||
):
|
||||
return grpc.experimental.stream_stream(
|
||||
request_iterator,
|
||||
target,
|
||||
"/agents.AgentRpc/OpenChannel",
|
||||
agent__worker__pb2.Message.SerializeToString,
|
||||
agent__worker__pb2.Message.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def GetState(
|
||||
request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None,
|
||||
):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
"/agents.AgentRpc/GetState",
|
||||
agent__worker__pb2.AgentId.SerializeToString,
|
||||
agent__worker__pb2.GetStateResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def SaveState(
|
||||
request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None,
|
||||
):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
"/agents.AgentRpc/SaveState",
|
||||
agent__worker__pb2.AgentState.SerializeToString,
|
||||
agent__worker__pb2.SaveStateResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
)
|
||||
@@ -4,16 +4,16 @@ isort:skip_file
|
||||
"""
|
||||
|
||||
import abc
|
||||
import agent_worker_pb2
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
import agent_worker_pb2
|
||||
import grpc
|
||||
import grpc.aio
|
||||
import typing
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
|
||||
|
||||
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg]
|
||||
...
|
||||
|
||||
@@ -56,20 +56,26 @@ class AgentRpcServicer(metaclass=abc.ABCMeta):
|
||||
self,
|
||||
request_iterator: _MaybeAsyncIterator[agent_worker_pb2.Message],
|
||||
context: _ServicerContext,
|
||||
) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ...
|
||||
|
||||
) -> typing.Union[
|
||||
collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]
|
||||
]: ...
|
||||
@abc.abstractmethod
|
||||
def GetState(
|
||||
self,
|
||||
request: agent_worker_pb2.AgentId,
|
||||
context: _ServicerContext,
|
||||
) -> typing.Union[agent_worker_pb2.GetStateResponse, collections.abc.Awaitable[agent_worker_pb2.GetStateResponse]]: ...
|
||||
|
||||
) -> typing.Union[
|
||||
agent_worker_pb2.GetStateResponse, collections.abc.Awaitable[agent_worker_pb2.GetStateResponse]
|
||||
]: ...
|
||||
@abc.abstractmethod
|
||||
def SaveState(
|
||||
self,
|
||||
request: agent_worker_pb2.AgentState,
|
||||
context: _ServicerContext,
|
||||
) -> typing.Union[agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]]: ...
|
||||
) -> typing.Union[
|
||||
agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]
|
||||
]: ...
|
||||
|
||||
def add_AgentRpcServicer_to_server(servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...
|
||||
def add_AgentRpcServicer_to_server(
|
||||
servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]
|
||||
) -> None: ...
|
||||
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: cloudevent.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
||||
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x10\x63loudevent.proto\x12\ncloudevent\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xa4\x05\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12:\n\nattributes\x18\x05 \x03(\x0b\x32&.cloudevent.CloudEvent.AttributesEntry\x12\x36\n\x08metadata\x18\x06 \x03(\x0b\x32$.cloudevent.CloudEvent.MetadataEntry\x12\x17\n\x0f\x64\x61tacontenttype\x18\x07 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x08 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\t \x01(\tH\x00\x12*\n\nproto_data\x18\n \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1a\x62\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12>\n\x05value\x18\x02 \x01(\x0b\x32/.cloudevent.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3'
|
||||
)
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "cloudevent_pb2", _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals["DESCRIPTOR"]._options = None
|
||||
_globals["DESCRIPTOR"]._serialized_options = b"\252\002\036Microsoft.AutoGen.Abstractions"
|
||||
_globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._options = None
|
||||
_globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._serialized_options = b"8\001"
|
||||
_globals["_CLOUDEVENT_METADATAENTRY"]._options = None
|
||||
_globals["_CLOUDEVENT_METADATAENTRY"]._serialized_options = b"8\001"
|
||||
_globals["_CLOUDEVENT"]._serialized_start = 93
|
||||
_globals["_CLOUDEVENT"]._serialized_end = 769
|
||||
_globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._serialized_start = 400
|
||||
_globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._serialized_end = 498
|
||||
_globals["_CLOUDEVENT_METADATAENTRY"]._serialized_start = 500
|
||||
_globals["_CLOUDEVENT_METADATAENTRY"]._serialized_end = 547
|
||||
_globals["_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE"]._serialized_start = 550
|
||||
_globals["_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE"]._serialized_end = 761
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -5,12 +5,13 @@ isort:skip_file
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
import google.protobuf.any_pb2
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.message
|
||||
import google.protobuf.timestamp_pb2
|
||||
import typing
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@@ -87,9 +88,54 @@ class CloudEvent(google.protobuf.message.Message):
|
||||
ce_uri_ref: builtins.str = ...,
|
||||
ce_timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["attr", b"attr"]) -> typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"] | None: ...
|
||||
def HasField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"attr",
|
||||
b"attr",
|
||||
"ce_boolean",
|
||||
b"ce_boolean",
|
||||
"ce_bytes",
|
||||
b"ce_bytes",
|
||||
"ce_integer",
|
||||
b"ce_integer",
|
||||
"ce_string",
|
||||
b"ce_string",
|
||||
"ce_timestamp",
|
||||
b"ce_timestamp",
|
||||
"ce_uri",
|
||||
b"ce_uri",
|
||||
"ce_uri_ref",
|
||||
b"ce_uri_ref",
|
||||
],
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"attr",
|
||||
b"attr",
|
||||
"ce_boolean",
|
||||
b"ce_boolean",
|
||||
"ce_bytes",
|
||||
b"ce_bytes",
|
||||
"ce_integer",
|
||||
b"ce_integer",
|
||||
"ce_string",
|
||||
b"ce_string",
|
||||
"ce_timestamp",
|
||||
b"ce_timestamp",
|
||||
"ce_uri",
|
||||
b"ce_uri",
|
||||
"ce_uri_ref",
|
||||
b"ce_uri_ref",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(
|
||||
self, oneof_group: typing.Literal["attr", b"attr"]
|
||||
) -> (
|
||||
typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"]
|
||||
| None
|
||||
): ...
|
||||
|
||||
ID_FIELD_NUMBER: builtins.int
|
||||
SOURCE_FIELD_NUMBER: builtins.int
|
||||
@@ -115,7 +161,9 @@ class CloudEvent(google.protobuf.message.Message):
|
||||
binary_data: builtins.bytes
|
||||
text_data: builtins.str
|
||||
@property
|
||||
def attributes(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]:
|
||||
def attributes(
|
||||
self,
|
||||
) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]:
|
||||
"""Optional & Extension Attributes"""
|
||||
|
||||
@property
|
||||
@@ -136,8 +184,41 @@ class CloudEvent(google.protobuf.message.Message):
|
||||
text_data: builtins.str = ...,
|
||||
proto_data: google.protobuf.any_pb2.Any | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["attributes", b"attributes", "binary_data", b"binary_data", "data", b"data", "datacontenttype", b"datacontenttype", "id", b"id", "metadata", b"metadata", "proto_data", b"proto_data", "source", b"source", "spec_version", b"spec_version", "text_data", b"text_data", "type", b"type"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
|
||||
def HasField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"
|
||||
],
|
||||
) -> builtins.bool: ...
|
||||
def ClearField(
|
||||
self,
|
||||
field_name: typing.Literal[
|
||||
"attributes",
|
||||
b"attributes",
|
||||
"binary_data",
|
||||
b"binary_data",
|
||||
"data",
|
||||
b"data",
|
||||
"datacontenttype",
|
||||
b"datacontenttype",
|
||||
"id",
|
||||
b"id",
|
||||
"metadata",
|
||||
b"metadata",
|
||||
"proto_data",
|
||||
b"proto_data",
|
||||
"source",
|
||||
b"source",
|
||||
"spec_version",
|
||||
b"spec_version",
|
||||
"text_data",
|
||||
b"text_data",
|
||||
"type",
|
||||
b"type",
|
||||
],
|
||||
) -> None: ...
|
||||
def WhichOneof(
|
||||
self, oneof_group: typing.Literal["data", b"data"]
|
||||
) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
|
||||
|
||||
global___CloudEvent = CloudEvent
|
||||
@@ -0,0 +1,4 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
|
||||
import grpc
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import abc
|
||||
import collections.abc
|
||||
import typing
|
||||
|
||||
import grpc
|
||||
import grpc.aio
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
|
||||
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg]
|
||||
...
|
||||
@@ -3,8 +3,15 @@ from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_core import AgentId, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
DefaultTopicId,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
default_subscription,
|
||||
message_handler,
|
||||
)
|
||||
from autogen_core.components.models import ChatCompletionClient, CreateResult, SystemMessage, UserMessage
|
||||
from autogen_ext.models import ReplayChatCompletionClient
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: serialization_test.proto
|
||||
# Protobuf Python Version: 4.25.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18serialization_test.proto\x12\x06\x61gents"\x1f\n\x0cProtoMessage\x12\x0f\n\x07message\x18\x01 \x01(\t"L\n\x13NestingProtoMessage\x12\x0f\n\x07message\x18\x01 \x01(\t\x12$\n\x06nested\x18\x02 \x01(\x0b\x32\x14.agents.ProtoMessageb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "serialization_test_pb2", _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
DESCRIPTOR._options = None
|
||||
_globals["_PROTOMESSAGE"]._serialized_start=36
|
||||
_globals["_PROTOMESSAGE"]._serialized_end=67
|
||||
_globals["_NESTINGPROTOMESSAGE"]._serialized_start=69
|
||||
_globals["_NESTINGPROTOMESSAGE"]._serialized_end=145
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.message
|
||||
import typing
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class ProtoMessage(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
MESSAGE_FIELD_NUMBER: builtins.int
|
||||
message: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["message", b"message"]) -> None: ...
|
||||
|
||||
global___ProtoMessage = ProtoMessage
|
||||
|
||||
@typing.final
|
||||
class NestingProtoMessage(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
MESSAGE_FIELD_NUMBER: builtins.int
|
||||
NESTED_FIELD_NUMBER: builtins.int
|
||||
message: builtins.str
|
||||
@property
|
||||
def nested(self) -> global___ProtoMessage: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message: builtins.str = ...,
|
||||
nested: global___ProtoMessage | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["nested", b"nested"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["message", b"message", "nested", b"nested"]) -> None: ...
|
||||
|
||||
global___NestingProtoMessage = NestingProtoMessage
|
||||
@@ -19,9 +19,8 @@ from autogen_core import (
|
||||
try_get_known_serializers_for_type,
|
||||
type_subscription,
|
||||
)
|
||||
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
|
||||
from protos.serialization_test_pb2 import ProtoMessage
|
||||
from test_utils import (
|
||||
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost
|
||||
from autogen_test_utils import (
|
||||
CascadingAgent,
|
||||
CascadingMessageType,
|
||||
ContentMessage,
|
||||
@@ -30,15 +29,16 @@ from test_utils import (
|
||||
MessageType,
|
||||
NoopAgent,
|
||||
)
|
||||
from protos.serialization_test_pb2 import ProtoMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_types_must_be_unique_single_worker() -> None:
|
||||
host_address = "localhost:50051"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
|
||||
await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
@@ -57,12 +57,12 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_types_must_be_unique_multiple_workers() -> None:
|
||||
host_address = "localhost:50052"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
worker2 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker2.start()
|
||||
|
||||
await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
@@ -82,10 +82,10 @@ async def test_agent_types_must_be_unique_multiple_workers() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish() -> None:
|
||||
host_address = "localhost:50053"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1.start()
|
||||
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker1.register_factory(
|
||||
@@ -93,7 +93,7 @@ async def test_register_receives_publish() -> None:
|
||||
)
|
||||
await worker1.add_subscription(TypeSubscription("default", "name1"))
|
||||
|
||||
worker2 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker2.start()
|
||||
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
await worker2.register_factory(
|
||||
@@ -127,9 +127,9 @@ async def test_register_receives_publish() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_cascade_single_worker() -> None:
|
||||
host_address = "localhost:50054"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
|
||||
num_agents = 5
|
||||
@@ -164,7 +164,7 @@ async def test_register_receives_publish_cascade_single_worker() -> None:
|
||||
async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
host_address = "localhost:50055"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
|
||||
# TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail.
|
||||
@@ -176,16 +176,16 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
||||
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
|
||||
|
||||
# Run multiple workers one for each agent.
|
||||
workers: List[WorkerAgentRuntime] = []
|
||||
workers: List[GrpcWorkerAgentRuntime] = []
|
||||
# Register agents
|
||||
for i in range(num_agents):
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
|
||||
workers.append(runtime)
|
||||
|
||||
# Publish messages
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
|
||||
publisher.start()
|
||||
for _ in range(num_initial_messages):
|
||||
@@ -207,11 +207,11 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription() -> None:
|
||||
host_address = "localhost:50056"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
|
||||
@@ -241,11 +241,11 @@ async def test_default_subscription() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_subscription_other_source() -> None:
|
||||
host_address = "localhost:50057"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
runtime = WorkerAgentRuntime(host_address=host_address)
|
||||
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
runtime.start()
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
|
||||
@@ -275,11 +275,11 @@ async def test_default_subscription_other_source() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_type_subscription() -> None:
|
||||
host_address = "localhost:50058"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
worker = WorkerAgentRuntime(host_address=host_address)
|
||||
worker = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker.start()
|
||||
publisher = WorkerAgentRuntime(host_address=host_address)
|
||||
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
publisher.start()
|
||||
|
||||
@@ -312,9 +312,9 @@ async def test_type_subscription() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_subscription() -> None:
|
||||
host_address = "localhost:50059"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1_2 = WorkerAgentRuntime(host_address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
host.start()
|
||||
try:
|
||||
worker1.start()
|
||||
@@ -343,10 +343,10 @@ async def test_duplicate_subscription() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnected_agent() -> None:
|
||||
host_address = "localhost:50060"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1_2 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
|
||||
# TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
|
||||
# to some private properties. This needs to be updated once they are available publicly
|
||||
@@ -421,13 +421,13 @@ class ProtoReceivingAgent(RoutedAgent):
|
||||
@pytest.mark.asyncio
|
||||
async def test_proto_payloads() -> None:
|
||||
host_address = "localhost:50057"
|
||||
host = WorkerAgentRuntimeHost(address=host_address)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address)
|
||||
host.start()
|
||||
receiver_runtime = WorkerAgentRuntime(
|
||||
receiver_runtime = GrpcWorkerAgentRuntime(
|
||||
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
|
||||
)
|
||||
receiver_runtime.start()
|
||||
publisher_runtime = WorkerAgentRuntime(
|
||||
publisher_runtime = GrpcWorkerAgentRuntime(
|
||||
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
|
||||
)
|
||||
publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
|
||||
@@ -473,10 +473,10 @@ async def test_grpc_max_message_size() -> None:
|
||||
("grpc.max_receive_message_length", new_max_size),
|
||||
]
|
||||
host_address = "localhost:50061"
|
||||
host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
worker2 = WorkerAgentRuntime(host_address=host_address)
|
||||
worker3 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
host = GrpcWorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
worker1 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
|
||||
worker3 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
|
||||
|
||||
try:
|
||||
host.start()
|
||||
@@ -5,8 +5,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from autogen_core import AgentId, AgentProxy
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import AgentId, AgentProxy, SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_core.components.code_executor import CodeBlock
|
||||
from autogen_ext.code_executors import DockerCommandLineCodeExecutor
|
||||
|
||||
@@ -7,8 +7,7 @@ round-robin orchestrator agent. The code snippets are executed inside a docker c
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from autogen_core import AgentId, AgentProxy
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import AgentId, AgentProxy, SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_core.components.code_executor import CodeBlock
|
||||
from autogen_ext.code_executors import DockerCommandLineCodeExecutor
|
||||
|
||||
@@ -5,8 +5,7 @@ to write input or perform actions, orchestrated by an round-robin orchestrator a
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from autogen_core import AgentId, AgentProxy
|
||||
from autogen_core.application import SingleThreadedAgentRuntime
|
||||
from autogen_core import AgentId, AgentProxy, SingleThreadedAgentRuntime
|
||||
from autogen_core.application.logging import EVENT_LOGGER_NAME
|
||||
from autogen_magentic_one.agents.file_surfer import FileSurfer
|
||||
from autogen_magentic_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user