Move grpc runtimes to ext, flatten application (#4553)

* Move grpc runtimes to ext, flatten application

* rename to grpc

* fmt
This commit is contained in:
Jack Gerrits
2024-12-04 16:23:20 -08:00
committed by GitHub
parent 777f2abbd7
commit 2b878763f8
113 changed files with 4712 additions and 4457 deletions

View File

@@ -9,7 +9,7 @@ from openai import AzureOpenAI
from typing import List
from autogen_core import AgentId, AgentProxy, TopicId
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_core.components.models import (
ChatCompletionClient,

View File

@@ -9,7 +9,7 @@ from openai import AzureOpenAI
from typing import List
from autogen_core import AgentId, AgentProxy, TopicId
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_core.components.models import (
ChatCompletionClient,

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
from autogen_core import AgentId, AgentProxy, TopicId
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_core import DefaultSubscription, DefaultTopicId
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor

View File

@@ -8,7 +8,7 @@ import nltk
from typing import Any, Dict, List, Tuple, Union
from autogen_core import AgentId, AgentProxy, TopicId
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_core import DefaultSubscription, DefaultTopicId
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor

View File

@@ -12,10 +12,10 @@ from autogen_core import (
CancellationToken,
ClosureAgent,
MessageContext,
SingleThreadedAgentRuntime,
TypeSubscription,
)
from autogen_core._closure_agent import ClosureContext
from autogen_core.application import SingleThreadedAgentRuntime
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition

View File

@@ -5,8 +5,14 @@ from typing import List
import pytest
from autogen_agentchat.teams._group_chat._sequential_routed_agent import SequentialRoutedAgent
from autogen_core import AgentId, DefaultTopicId, MessageContext, default_subscription, message_handler
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import (
AgentId,
DefaultTopicId,
MessageContext,
SingleThreadedAgentRuntime,
default_subscription,
message_handler,
)
@dataclass

View File

@@ -25,8 +25,14 @@
"import asyncio\n",
"from dataclasses import dataclass\n",
"\n",
"from autogen_core import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId, MessageContext\n",
"from autogen_core.application import SingleThreadedAgentRuntime"
"from autogen_core import (\n",
" ClosureAgent,\n",
" ClosureContext,\n",
" DefaultSubscription,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
" SingleThreadedAgentRuntime,\n",
")"
]
},
{

View File

@@ -35,15 +35,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any, Callable, List, Literal\n",
"\n",
"from autogen_core import AgentId, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"from langchain_core.tools import tool # pyright: ignore\n",

View File

@@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -41,8 +41,7 @@
"from dataclasses import dataclass\n",
"from typing import List, Optional\n",
"\n",
"from autogen_core import AgentId, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
"from azure.identity import DefaultAzureCredential, get_bearer_token_provider\n",
"from llama_index.core import Settings\n",
"from llama_index.core.agent import ReActAgent\n",

View File

@@ -39,8 +39,15 @@
"source": [
"from dataclasses import dataclass\n",
"\n",
"from autogen_core import AgentId, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" default_subscription,\n",
" message_handler,\n",
")\n",
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
"from autogen_core.components.models import (\n",
" AssistantMessage,\n",

View File

@@ -386,7 +386,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import SingleThreadedAgentRuntime\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"await OpenAIAssistantAgent.register(\n",

View File

@@ -22,8 +22,15 @@
"from dataclasses import dataclass\n",
"from typing import Any\n",
"\n",
"from autogen_core import AgentId, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" default_subscription,\n",
" message_handler,\n",
")\n",
"from autogen_core.base.intervention import DefaultInterventionHandler"
]
},

View File

@@ -1,283 +1,290 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# User Approval for Tool Execution using Intervention Handler\n",
"\n",
"This cookbook shows how to intercept the tool execution using\n",
"an intervention hanlder, and prompt the user for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any, List\n",
"\n",
"from autogen_core import AgentId, AgentType, FunctionCall, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a simple message type that carries a string content."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" content: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple tool use agent that is capable of using tools through a\n",
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"class ToolUseAgent(RoutedAgent):\n",
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" description: str,\n",
" system_messages: List[SystemMessage],\n",
" model_client: ChatCompletionClient,\n",
" tool_schema: List[ToolSchema],\n",
" tool_agent_type: AgentType,\n",
" ) -> None:\n",
" super().__init__(description)\n",
" self._model_client = model_client\n",
" self._system_messages = system_messages\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
" # Use the tool agent to execute the tools, and get the output messages.\n",
" output_messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Extract the final response from the output messages.\n",
" final_response = output_messages[-1].content\n",
" assert isinstance(final_response, str)\n",
" return Message(content=final_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
"to prompt the user for permission to execute the tool.\n",
"\n",
"Let's create an intervention handler that intercepts the messages and prompts\n",
"user for before allowing the tool execution."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
" )\n",
" if user_input.strip().lower() != \"y\":\n",
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
" return message"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a runtime with the intervention handler registered."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Create the runtime with the intervention handler.\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we will use a tool for Python code execution.\n",
"First, we create a Docker-based command-line code executor\n",
"using {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`,\n",
"and then use it to instantiate a built-in Python code execution tool\n",
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
"that runs code in a Docker container."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Create the docker executor for the Python code execution tool.\n",
"docker_executor = DockerCommandLineCodeExecutor()\n",
"\n",
"# Create the Python code execution tool.\n",
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register the agents with tools and tool schema."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_enabled_agent')"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Register agents.\n",
"tool_agent_type = await ToolAgent.register(\n",
" runtime,\n",
" \"tool_executor_agent\",\n",
" lambda: ToolAgent(\n",
" description=\"Tool Executor Agent\",\n",
" tools=[python_tool],\n",
" ),\n",
")\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_enabled_agent\",\n",
" lambda: ToolUseAgent(\n",
" description=\"Tool Use Agent\",\n",
" system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
" tool_schema=[python_tool.schema],\n",
" tool_agent_type=tool_agent_type,\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
"The intervention handler will prompt you for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output of the code is: **Hello, World!**\n"
]
}
],
"source": [
"# Start the runtime and the docker executor.\n",
"await docker_executor.start()\n",
"runtime.start()\n",
"\n",
"# Send a task to the tool user.\n",
"response = await runtime.send_message(\n",
" Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
")\n",
"print(response.content)\n",
"\n",
"# Stop the runtime and the docker executor.\n",
"await runtime.stop()\n",
"await docker_executor.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# User Approval for Tool Execution using Intervention Handler\n",
"\n",
"This cookbook shows how to intercept the tool execution using\n",
"an intervention hanlder, and prompt the user for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any, List\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" AgentType,\n",
" FunctionCall,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" message_handler,\n",
")\n",
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a simple message type that carries a string content."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" content: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple tool use agent that is capable of using tools through a\n",
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"class ToolUseAgent(RoutedAgent):\n",
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" description: str,\n",
" system_messages: List[SystemMessage],\n",
" model_client: ChatCompletionClient,\n",
" tool_schema: List[ToolSchema],\n",
" tool_agent_type: AgentType,\n",
" ) -> None:\n",
" super().__init__(description)\n",
" self._model_client = model_client\n",
" self._system_messages = system_messages\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
" # Use the tool agent to execute the tools, and get the output messages.\n",
" output_messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Extract the final response from the output messages.\n",
" final_response = output_messages[-1].content\n",
" assert isinstance(final_response, str)\n",
" return Message(content=final_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
"to prompt the user for permission to execute the tool.\n",
"\n",
"Let's create an intervention handler that intercepts the messages and prompts\n",
"user for before allowing the tool execution."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
" )\n",
" if user_input.strip().lower() != \"y\":\n",
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
" return message"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a runtime with the intervention handler registered."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Create the runtime with the intervention handler.\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we will use a tool for Python code execution.\n",
"First, we create a Docker-based command-line code executor\n",
"using {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`,\n",
"and then use it to instantiate a built-in Python code execution tool\n",
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
"that runs code in a Docker container."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Create the docker executor for the Python code execution tool.\n",
"docker_executor = DockerCommandLineCodeExecutor()\n",
"\n",
"# Create the Python code execution tool.\n",
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register the agents with tools and tool schema."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_enabled_agent')"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Register agents.\n",
"tool_agent_type = await ToolAgent.register(\n",
" runtime,\n",
" \"tool_executor_agent\",\n",
" lambda: ToolAgent(\n",
" description=\"Tool Executor Agent\",\n",
" tools=[python_tool],\n",
" ),\n",
")\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_enabled_agent\",\n",
" lambda: ToolUseAgent(\n",
" description=\"Tool Use Agent\",\n",
" system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
" tool_schema=[python_tool.schema],\n",
" tool_agent_type=tool_agent_type,\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
"The intervention handler will prompt you for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output of the code is: **Hello, World!**\n"
]
}
],
"source": [
"# Start the runtime and the docker executor.\n",
"await docker_executor.start()\n",
"runtime.start()\n",
"\n",
"# Send a task to the tool user.\n",
"response = await runtime.send_message(\n",
" Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
")\n",
"print(response.content)\n",
"\n",
"# Stop the runtime and the docker executor.\n",
"await runtime.stop()\n",
"await docker_executor.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -34,13 +34,13 @@
" DefaultTopicId,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" TopicId,\n",
" TypeSubscription,\n",
" default_subscription,\n",
" message_handler,\n",
" type_subscription,\n",
")\n",
"from autogen_core.application import SingleThreadedAgentRuntime"
")"
]
},
{

View File

@@ -78,11 +78,11 @@
" Image,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" TopicId,\n",
" TypeSubscription,\n",
" message_handler,\n",
")\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.components.models import (\n",
" AssistantMessage,\n",
" ChatCompletionClient,\n",

View File

@@ -56,8 +56,15 @@
"import uuid\n",
"from typing import List, Tuple\n",
"\n",
"from autogen_core import FunctionCall, MessageContext, RoutedAgent, TopicId, TypeSubscription, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import (\n",
" FunctionCall,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" TopicId,\n",
" TypeSubscription,\n",
" message_handler,\n",
")\n",
"from autogen_core.components.models import (\n",
" AssistantMessage,\n",
" ChatCompletionClient,\n",

View File

@@ -441,8 +441,7 @@
}
],
"source": [
"from autogen_core import DefaultTopicId\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import DefaultTopicId, SingleThreadedAgentRuntime\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",

View File

@@ -18,7 +18,7 @@ The key can correspond to a user id, a session id, or could just be "default" if
## How do I increase the GRPC message size?
If you need to provide custom gRPC options, such as overriding the `max_send_message_length` and `max_receive_message_length`, you can define an `extra_grpc_config` variable and pass it to both the `WorkerAgentRuntimeHost` and `WorkerAgentRuntime` instances.
If you need to provide custom gRPC options, such as overriding the `max_send_message_length` and `max_receive_message_length`, you can define an `extra_grpc_config` variable and pass it to both the `GrpcWorkerAgentRuntimeHost` and `GrpcWorkerAgentRuntime` instances.
```python
# Define custom gRPC options
@@ -27,10 +27,10 @@ extra_grpc_config = [
("grpc.max_receive_message_length", new_max_size),
]
# Create instances of WorkerAgentRuntimeHost and WorkerAgentRuntime with the custom gRPC options
# Create instances of GrpcWorkerAgentRuntimeHost and GrpcWorkerAgentRuntime with the custom gRPC options
host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
host = GrpcWorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
worker1 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
```
**Note**: When `WorkerAgentRuntime` creates a host connection for the clients, it uses `DEFAULT_GRPC_CONFIG` from `HostConnection` class as default set of values which will can be overriden if you pass parameters with the same name using `extra_grpc_config`.
**Note**: When `GrpcWorkerAgentRuntime` creates a host connection for the clients, it uses `DEFAULT_GRPC_CONFIG` from `HostConnection` class as default set of values which will can be overriden if you pass parameters with the same name using `extra_grpc_config`.

View File

@@ -117,7 +117,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -132,7 +132,7 @@
}
],
"source": [
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import SingleThreadedAgentRuntime\n",
"\n",
"runtime = SingleThreadedAgentRuntime()\n",
"await MyAgent.register(runtime, \"my_agent\", lambda: MyAgent())"

View File

@@ -28,18 +28,18 @@
"```\n",
"````\n",
"\n",
"We can start a host service using {py:class}`~autogen_core.application.WorkerAgentRuntimeHost`."
"We can start a host service using {py:class}`~autogen_core.application.GrpcWorkerAgentRuntimeHost`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autogen_core.application import WorkerAgentRuntimeHost\n",
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n",
"\n",
"host = WorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
"host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n",
"host.start() # Start a host service in the background."
]
},
@@ -94,7 +94,7 @@
"metadata": {},
"source": [
"Now we can set up the worker agent runtimes.\n",
"We use {py:class}`~autogen_core.application.WorkerAgentRuntime`.\n",
"We use {py:class}`~autogen_core.application.GrpcWorkerAgentRuntime`.\n",
"We set up two worker runtimes. Each runtime hosts one agent.\n",
"All agents publish and subscribe to the default topic, so they can see all\n",
"messages being published.\n",
@@ -104,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -127,13 +127,13 @@
"source": [
"import asyncio\n",
"\n",
"from autogen_core.application import WorkerAgentRuntime\n",
"from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n",
"\n",
"worker1 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker1.start()\n",
"await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n",
"\n",
"worker2 = WorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n",
"worker2.start()\n",
"await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n",
"\n",
@@ -149,7 +149,7 @@
"source": [
"We can see each agent published exactly 5 messages.\n",
"\n",
"To stop the worker runtimes, we can call {py:meth}`~autogen_core.application.WorkerAgentRuntime.stop`."
"To stop the worker runtimes, we can call {py:meth}`~autogen_core.application.GrpcWorkerAgentRuntime.stop`."
]
},
{
@@ -169,7 +169,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can call {py:meth}`~autogen_core.application.WorkerAgentRuntimeHost.stop`\n",
"We can call {py:meth}`~autogen_core.application.GrpcWorkerAgentRuntimeHost.stop`\n",
"to stop the host service."
]
},

View File

@@ -90,8 +90,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core import AgentId, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
"\n",
"\n",
"class MyAgent(RoutedAgent):\n",
@@ -298,8 +297,7 @@
"source": [
"from dataclasses import dataclass\n",
"\n",
"from autogen_core import MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
"\n",
"\n",
"@dataclass\n",

View File

@@ -48,7 +48,7 @@ Now you can send the trace_provider when creating your runtime:
# for single threaded runtime
single_threaded_runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
# or for worker runtime
worker_runtime = WorkerAgentRuntime(tracer_provider=tracer_provider)
worker_runtime = GrpcWorkerAgentRuntime(tracer_provider=tracer_provider)
```
And that's it! Your application is now instrumented with open telemetry. You can now view your telemetry data in your telemetry backend.
@@ -65,5 +65,5 @@ tracer_provider = trace.get_tracer_provider()
# for single threaded runtime
single_threaded_runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
# or for worker runtime
worker_runtime = WorkerAgentRuntime(tracer_provider=tracer_provider)
worker_runtime = GrpcWorkerAgentRuntime(tracer_provider=tracer_provider)
```

View File

@@ -1,315 +1,321 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tools\n",
"\n",
"Tools are code that can be executed by an agent to perform actions. A tool\n",
"can be a simple function such as a calculator, or an API call to a third-party service\n",
"such as stock price lookup or weather forecast.\n",
"In the context of AI agents, tools are designed to be executed by agents in\n",
"response to model-generated function calls.\n",
"\n",
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
"tools and utilities for creating and running custom tools."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Built-in Tools\n",
"\n",
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
"which allows agents to execute Python code snippets.\n",
"\n",
"Here is how you create the tool and use it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, world!\n",
"\n"
]
}
],
"source": [
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"\n",
"# Create the tool.\n",
"code_executor = DockerCommandLineCodeExecutor()\n",
"await code_executor.start()\n",
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
"cancellation_token = CancellationToken()\n",
"\n",
"# Use the tool directly without an agent.\n",
"code = \"print('Hello, world!')\"\n",
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
"print(code_execution_tool.return_value_as_string(result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`\n",
"class is a built-in code executor that runs Python code snippets in a subprocess\n",
"in the local command line environment.\n",
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
"and provides a simple interface to execute Python code snippets.\n",
"\n",
"Other built-in tools will be added in the future."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Function Tools\n",
"\n",
"A tool can also be a simple Python function that performs a specific action.\n",
"To create a custom function tool, you just need to create a Python function\n",
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
"\n",
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
"to inform the LLM when and how to use a given function. The description provides context\n",
"about the functions purpose and intended use cases, while type annotations inform the LLM about\n",
"the expected parameters and return type.\n",
"\n",
"For example, a simple tool to obtain the stock price of a company might look like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80.44429939059668\n"
]
}
],
"source": [
"import random\n",
"\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
"from typing_extensions import Annotated\n",
"\n",
"\n",
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
" # Returns a random stock price for demonstration purposes.\n",
" return random.uniform(10, 200)\n",
"\n",
"\n",
"# Create a function tool.\n",
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
"\n",
"# Run the tool.\n",
"cancellation_token = CancellationToken()\n",
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
"\n",
"# Print the result.\n",
"print(stock_price_tool.return_value_as_string(result))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool-Equipped Agent\n",
"\n",
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`,\n",
"by using it in a composition pattern.\n",
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
"as an inner agent for executing tools."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from autogen_core import AgentId, AgentInstantiationContext, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" content: str\n",
"\n",
"\n",
"class ToolUseAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
" super().__init__(\"An agent with tools\")\n",
" self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
" self._model_client = model_client\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" # Create a session of messages.\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
" # Run the caller loop to handle tool calls.\n",
" messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Return the final response.\n",
" assert isinstance(messages[-1].content, str)\n",
" return Message(content=messages[-1].content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`, \n",
"to handle the interaction between the model and the tool agent.\n",
"The core idea can be described using a simple control flow graph:\n",
"\n",
"![ToolUseAgent control flow graph](tool-use-agent-cfg.svg)\n",
"\n",
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
"and determines whether the model has generated a tool call.\n",
"If the model has generated tool calls, then the handler sends a function call\n",
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
"to execute the tools,\n",
"and then queries the model again with the results of the tool calls.\n",
"This process continues until the model stops generating tool calls,\n",
"at which point the final response is returned to the user.\n",
"\n",
"By having the tool execution logic in a separate agent,\n",
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
"can be observed externally and intercepted if necessary.\n",
"\n",
"To run the agent, we need to create a runtime and register the agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_use_agent')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a runtime.\n",
"runtime = SingleThreadedAgentRuntime()\n",
"# Create the tools.\n",
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n",
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_use_agent\",\n",
" lambda: ToolUseAgent(\n",
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example uses the {py:class}`autogen_core.components.models.OpenAIChatCompletionClient`,\n",
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
"Let's test the agent with a question about stock price."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
]
}
],
"source": [
"# Start processing messages.\n",
"runtime.start()\n",
"# Send a direct message to the tool agent.\n",
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n",
"# Stop processing messages.\n",
"await runtime.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen_core",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tools\n",
"\n",
"Tools are code that can be executed by an agent to perform actions. A tool\n",
"can be a simple function such as a calculator, or an API call to a third-party service\n",
"such as stock price lookup or weather forecast.\n",
"In the context of AI agents, tools are designed to be executed by agents in\n",
"response to model-generated function calls.\n",
"\n",
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
"tools and utilities for creating and running custom tools."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Built-in Tools\n",
"\n",
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
"which allows agents to execute Python code snippets.\n",
"\n",
"Here is how you create the tool and use it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, world!\n",
"\n"
]
}
],
"source": [
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"\n",
"# Create the tool.\n",
"code_executor = DockerCommandLineCodeExecutor()\n",
"await code_executor.start()\n",
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
"cancellation_token = CancellationToken()\n",
"\n",
"# Use the tool directly without an agent.\n",
"code = \"print('Hello, world!')\"\n",
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
"print(code_execution_tool.return_value_as_string(result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`\n",
"class is a built-in code executor that runs Python code snippets in a subprocess\n",
"in the local command line environment.\n",
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
"and provides a simple interface to execute Python code snippets.\n",
"\n",
"Other built-in tools will be added in the future."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Function Tools\n",
"\n",
"A tool can also be a simple Python function that performs a specific action.\n",
"To create a custom function tool, you just need to create a Python function\n",
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
"\n",
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
"to inform the LLM when and how to use a given function. The description provides context\n",
"about the functions purpose and intended use cases, while type annotations inform the LLM about\n",
"the expected parameters and return type.\n",
"\n",
"For example, a simple tool to obtain the stock price of a company might look like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80.44429939059668\n"
]
}
],
"source": [
"import random\n",
"\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
"from typing_extensions import Annotated\n",
"\n",
"\n",
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
" # Returns a random stock price for demonstration purposes.\n",
" return random.uniform(10, 200)\n",
"\n",
"\n",
"# Create a function tool.\n",
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
"\n",
"# Run the tool.\n",
"cancellation_token = CancellationToken()\n",
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
"\n",
"# Print the result.\n",
"print(stock_price_tool.return_value_as_string(result))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool-Equipped Agent\n",
"\n",
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`,\n",
"by using it in a composition pattern.\n",
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
"as an inner agent for executing tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" AgentInstantiationContext,\n",
" MessageContext,\n",
" RoutedAgent,\n",
" SingleThreadedAgentRuntime,\n",
" message_handler,\n",
")\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" content: str\n",
"\n",
"\n",
"class ToolUseAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
" super().__init__(\"An agent with tools\")\n",
" self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
" self._model_client = model_client\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" # Create a session of messages.\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
" # Run the caller loop to handle tool calls.\n",
" messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Return the final response.\n",
" assert isinstance(messages[-1].content, str)\n",
" return Message(content=messages[-1].content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`, \n",
"to handle the interaction between the model and the tool agent.\n",
"The core idea can be described using a simple control flow graph:\n",
"\n",
"![ToolUseAgent control flow graph](tool-use-agent-cfg.svg)\n",
"\n",
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
"and determines whether the model has generated a tool call.\n",
"If the model has generated tool calls, then the handler sends a function call\n",
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
"to execute the tools,\n",
"and then queries the model again with the results of the tool calls.\n",
"This process continues until the model stops generating tool calls,\n",
"at which point the final response is returned to the user.\n",
"\n",
"By having the tool execution logic in a separate agent,\n",
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
"can be observed externally and intercepted if necessary.\n",
"\n",
"To run the agent, we need to create a runtime and register the agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_use_agent')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a runtime.\n",
"runtime = SingleThreadedAgentRuntime()\n",
"# Create the tools.\n",
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n",
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_use_agent\",\n",
" lambda: ToolUseAgent(\n",
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example uses the {py:class}`autogen_core.components.models.OpenAIChatCompletionClient`,\n",
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
"Let's test the agent with a question about stock price."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
]
}
],
"source": [
"# Start processing messages.\n",
"runtime.start()\n",
"# Send a direct message to the tool agent.\n",
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n",
"# Stop processing messages.\n",
"await runtime.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen_core",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -310,7 +310,7 @@
"source": [
"import tempfile\n",
"\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core import SingleThreadedAgentRuntime\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",

View File

@@ -25,7 +25,6 @@ dependencies = [
"opentelemetry-api~=1.27.0",
"asyncio_atexit",
"jsonref~=1.1.0",
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
[project.optional-dependencies]
@@ -35,6 +34,7 @@ grpc = [
[tool.uv]
dev-dependencies = [
"autogen_test_utils",
"aiofiles",
"azure-identity",
"chess",
@@ -75,6 +75,8 @@ dev-dependencies = [
"autodoc_pydantic~=2.2",
"pygments",
"autogen_ext==0.4.0.dev8",
# Documentation tooling
"sphinx-autobuild",
]

View File

@@ -7,8 +7,14 @@ import asyncio
import logging
from typing import Annotated, Literal
from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime, DefaultSubscription, DefaultTopicId
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import (
AgentId,
AgentInstantiationContext,
AgentRuntime,
DefaultSubscription,
DefaultTopicId,
SingleThreadedAgentRuntime,
)
from autogen_core.components.model_context import BufferedChatCompletionContext
from autogen_core.components.models import SystemMessage
from autogen_core.components.tools import FunctionTool

View File

@@ -1,8 +1,6 @@
# Distributed Group Chat
from autogen_core.application import WorkerAgentRuntimeHost
This example runs a gRPC server using [WorkerAgentRuntimeHost](../../src/autogen_core/application/_worker_runtime_host.py) and instantiates three distributed runtimes using [WorkerAgentRuntime](../../src/autogen_core/application/_worker_runtime.py). These runtimes connect to the gRPC server as hosts and facilitate a round-robin distributed group chat. This example leverages the [Azure OpenAI Service](https://azure.microsoft.com/en-us/products/ai-services/openai-service) to implement writer and editor LLM agents. Agents are instructed to provide concise answers, as the primary goal of this example is to showcase the distributed runtime rather than the quality of agent responses.
This example runs a gRPC server using [GrpcWorkerAgentRuntimeHost](../../src/autogen_core/application/_worker_runtime_host.py) and instantiates three distributed runtimes using [GrpcWorkerAgentRuntime](../../src/autogen_core/application/_worker_runtime.py). These runtimes connect to the gRPC server as hosts and facilitate a round-robin distributed group chat. This example leverages the [Azure OpenAI Service](https://azure.microsoft.com/en-us/products/ai-services/openai-service) to implement writer and editor LLM agents. Agents are instructed to provide concise answers, as the primary goal of this example is to showcase the distributed runtime rather than the quality of agent responses.
## Setup

View File

@@ -5,7 +5,6 @@ from uuid import uuid4
from _types import GroupChatMessage, MessageChunk, RequestToSpeak, UIAgentConfig
from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, message_handler
from autogen_core.application import WorkerAgentRuntime
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
@@ -13,6 +12,7 @@ from autogen_core.components.models import (
SystemMessage,
UserMessage,
)
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
from rich.console import Console
from rich.markdown import Markdown
@@ -168,7 +168,7 @@ class UIAgent(RoutedAgent):
async def publish_message_to_ui(
runtime: RoutedAgent | WorkerAgentRuntime,
runtime: RoutedAgent | GrpcWorkerAgentRuntime,
source: str,
user_message: str,
ui_config: UIAgentConfig,
@@ -193,7 +193,7 @@ async def publish_message_to_ui(
async def publish_message_to_ui_and_backend(
runtime: RoutedAgent | WorkerAgentRuntime,
runtime: RoutedAgent | GrpcWorkerAgentRuntime,
source: str,
user_message: str,
ui_config: UIAgentConfig,

View File

@@ -8,15 +8,15 @@ from _utils import get_serializers, load_config, set_all_log_levels
from autogen_core import (
TypeSubscription,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.models import AzureOpenAIChatCompletionClient
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
from rich.console import Console
from rich.markdown import Markdown
async def main(config: AppConfig):
set_all_log_levels(logging.ERROR)
editor_agent_runtime = WorkerAgentRuntime(host_address=config.host.address)
editor_agent_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
editor_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(4)
Console().print(Markdown("Starting **`Editor Agent`**"))

View File

@@ -8,8 +8,8 @@ from _utils import get_serializers, load_config, set_all_log_levels
from autogen_core import (
TypeSubscription,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.models import AzureOpenAIChatCompletionClient
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
from rich.console import Console
from rich.markdown import Markdown
@@ -18,7 +18,7 @@ set_all_log_levels(logging.ERROR)
async def main(config: AppConfig):
set_all_log_levels(logging.ERROR)
group_chat_manager_runtime = WorkerAgentRuntime(host_address=config.host.address)
group_chat_manager_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
group_chat_manager_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(1)

View File

@@ -2,13 +2,13 @@ import asyncio
from _types import HostConfig
from _utils import load_config
from autogen_core.application import WorkerAgentRuntimeHost
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost
from rich.console import Console
from rich.markdown import Markdown
async def main(host_config: HostConfig):
host = WorkerAgentRuntimeHost(address=host_config.address)
host = GrpcWorkerAgentRuntimeHost(address=host_config.address)
host.start()
console = Console()

View File

@@ -9,7 +9,7 @@ from _utils import get_serializers, load_config, set_all_log_levels
from autogen_core import (
TypeSubscription,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
from chainlit import Message # type: ignore [reportAttributeAccessIssue]
from rich.console import Console
from rich.markdown import Markdown
@@ -36,7 +36,7 @@ async def send_cl_stream(msg: MessageChunk) -> None:
async def main(config: AppConfig):
set_all_log_levels(logging.ERROR)
ui_agent_runtime = WorkerAgentRuntime(host_address=config.host.address)
ui_agent_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
ui_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]

View File

@@ -8,15 +8,15 @@ from _utils import get_serializers, load_config, set_all_log_levels
from autogen_core import (
TypeSubscription,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.models import AzureOpenAIChatCompletionClient
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
from rich.console import Console
from rich.markdown import Markdown
async def main(config: AppConfig) -> None:
set_all_log_levels(logging.ERROR)
writer_agent_runtime = WorkerAgentRuntime(host_address=config.host.address)
writer_agent_runtime = GrpcWorkerAgentRuntime(host_address=config.host.address)
writer_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(3)
Console().print(Markdown("Starting **`Writer Agent`**"))

View File

@@ -2,12 +2,12 @@ import asyncio
import logging
import platform
from autogen_core.application import WorkerAgentRuntimeHost
from autogen_core.application.logging import TRACE_LOGGER_NAME
from autogen_core import TRACE_LOGGER_NAME
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost
async def run_host():
host = WorkerAgentRuntimeHost(address="localhost:50051")
host = GrpcWorkerAgentRuntimeHost(address="localhost:50051")
host.start() # Start a host service in the background.
if platform.system() == "Windows":
try:

View File

@@ -32,7 +32,7 @@ from _semantic_router_components import (
WorkerAgentMessage,
)
from autogen_core import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId, MessageContext
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
class MockIntentClassifier(IntentClassifierBase):
@@ -78,7 +78,7 @@ async def output_result(
async def run_workers():
agent_runtime = WorkerAgentRuntime(host_address="localhost:50051")
agent_runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
agent_runtime.start()

View File

@@ -37,10 +37,10 @@ from autogen_core import (
FunctionCall,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
message_handler,
type_subscription,
)
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base.intervention import DefaultInterventionHandler
from autogen_core.components.model_context import BufferedChatCompletionContext
from autogen_core.components.models import (

View File

@@ -1,10 +1,10 @@
from agents import CascadingMessage, ObserverAgent
from autogen_core import DefaultTopicId, try_get_known_serializers_for_type
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
runtime.start()
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())

View File

@@ -2,11 +2,11 @@ import uuid
from agents import CascadingAgent, ReceiveMessageEvent
from autogen_core import try_get_known_serializers_for_type
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
runtime.start()
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")

View File

@@ -1,10 +1,10 @@
import asyncio
from autogen_core.application import WorkerAgentRuntimeHost
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost
async def main() -> None:
service = WorkerAgentRuntimeHost(address="localhost:50051")
service = GrpcWorkerAgentRuntimeHost(address="localhost:50051")
service.start()
await service.stop_when_signal()

View File

@@ -11,7 +11,7 @@ from autogen_core import (
message_handler,
try_get_known_serializers_for_type,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
@dataclass
@@ -72,7 +72,7 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
for t in [AskToGreet, Greeting, ReturnedGreeting, Feedback, ReturnedFeedback]:
runtime.add_message_serializer(try_get_known_serializers_for_type(t))

View File

@@ -10,7 +10,7 @@ from autogen_core import (
RoutedAgent,
message_handler,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
@dataclass
@@ -53,7 +53,7 @@ class GreeterAgent(RoutedAgent):
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
await ReceiveAgent.register(

View File

@@ -12,7 +12,7 @@ from autogen_core import (
TypeSubscription,
try_get_known_serializers_for_type,
)
from autogen_core.application import WorkerAgentRuntime
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime
# Add the local package directory to sys.path
thisdir = os.path.dirname(os.path.abspath(__file__))
@@ -34,7 +34,7 @@ async def main() -> None:
agentHost = agentHost[8:]
agnext_logger.info("0")
agnext_logger.info(agentHost)
runtime = WorkerAgentRuntime(host_address=agentHost, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE)
runtime = GrpcWorkerAgentRuntime(host_address=agentHost, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE)
agnext_logger.info("1")
runtime.start()

View File

@@ -12,6 +12,7 @@ from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._closure_agent import ClosureAgent, ClosureContext
from ._constants import EVENT_LOGGER_NAME, ROOT_LOGGER_NAME, TRACE_LOGGER_NAME
from ._default_subscription import DefaultSubscription, default_subscription, type_subscription
from ._default_topic import DefaultTopicId
from ._image import Image
@@ -25,6 +26,7 @@ from ._serialization import (
UnknownPayload,
try_get_known_serializers_for_type,
)
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
from ._subscription import Subscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
@@ -66,4 +68,8 @@ __all__ = [
"TypePrefixSubscription",
"JSON_DATA_CONTENT_TYPE",
"PROTOBUF_DATA_CONTENT_TYPE",
"SingleThreadedAgentRuntime",
"ROOT_LOGGER_NAME",
"EVENT_LOGGER_NAME",
"TRACE_LOGGER_NAME",
]

View File

@@ -0,0 +1,9 @@
ROOT_LOGGER_NAME = "autogen_core"
"""str: Logger name used for structured event logging"""
EVENT_LOGGER_NAME = "autogen_core.events"
"""str: Logger name used for structured event logging"""
TRACE_LOGGER_NAME = "autogen_core.trace"
"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""

View File

@@ -1,11 +1,11 @@
from collections import defaultdict
from typing import Awaitable, Callable, DefaultDict, List, Set
from .._agent import Agent
from .._agent_id import AgentId
from .._agent_type import AgentType
from .._subscription import Subscription
from .._topic import TopicId
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_type import AgentType
from ._subscription import Subscription
from ._topic import TopicId
async def get_impl(

View File

@@ -0,0 +1,687 @@
from __future__ import annotations
import asyncio
import inspect
import logging
import threading
import uuid
import warnings
from asyncio import CancelledError, Future, Task
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from opentelemetry.trace import TracerProvider
from typing_extensions import deprecated
from autogen_core._serialization import MessageSerializer, SerializationRegistry
from ._agent import Agent
from ._agent_id import AgentId
from ._agent_instantiation import AgentInstantiationContext
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._runtime_impl_helpers import SubscriptionManager, get_impl
from ._subscription import Subscription
from ._subscription_context import SubscriptionInstantiationContext
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
from ._topic import TopicId
from .base.intervention import DropMessage, InterventionHandler
from .exceptions import MessageDroppedException
logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
# We use a type parameter in some functions which shadows the built-in `type` function.
# This is a workaround to avoid shadowing the built-in `type` function.
type_func_alias = type
@dataclass(kw_only=True)
class PublishMessageEnvelope:
"""A message envelope for publishing messages to all agents that can handle
the message of the type T."""
message: Any
cancellation_token: CancellationToken
sender: AgentId | None
topic_id: TopicId
metadata: EnvelopeMetadata | None = None
message_id: str
@dataclass(kw_only=True)
class SendMessageEnvelope:
"""A message envelope for sending a message to a specific agent that can handle
the message of the type T."""
message: Any
sender: AgentId | None
recipient: AgentId
future: Future[Any]
cancellation_token: CancellationToken
metadata: EnvelopeMetadata | None = None
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""
message: Any
future: Future[Any]
sender: AgentId
recipient: AgentId | None
metadata: EnvelopeMetadata | None = None
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
class Counter:
def __init__(self) -> None:
self._count: int = 0
self.threadLock = threading.Lock()
def increment(self) -> None:
self.threadLock.acquire()
self._count += 1
self.threadLock.release()
def get(self) -> int:
return self._count
def decrement(self) -> None:
self.threadLock.acquire()
self._count -= 1
self.threadLock.release()
class RunContext:
class RunState(Enum):
RUNNING = 0
CANCELLED = 1
UNTIL_IDLE = 2
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
self._run_state = RunContext.RunState.RUNNING
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
self._run_task = asyncio.create_task(self._run())
self._lock = asyncio.Lock()
async def _run(self) -> None:
while True:
async with self._lock:
if self._end_condition():
return
await self._runtime.process_next()
async def stop(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.CANCELLED
self._end_condition = self._stop_when_cancelled
await self._run_task
async def stop_when_idle(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.UNTIL_IDLE
self._end_condition = self._stop_when_idle
await self._run_task
async def stop_when(self, condition: Callable[[], bool]) -> None:
async with self._lock:
self._end_condition = condition
await self._run_task
def _stop_when_cancelled(self) -> bool:
return self._run_state == RunContext.RunState.CANCELLED
def _stop_when_idle(self) -> bool:
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
def _warn_if_none(value: Any, handler_name: str) -> None:
"""
Utility function to check if the intervention handler returned None and issue a warning.
Args:
value: The return value to check
handler_name: Name of the intervention handler method for the warning message
"""
if value is None:
warnings.warn(
f"Intervention handler {handler_name} returned None. This might be unintentional. "
"Consider returning the original message or DropMessage explicitly.",
RuntimeWarning,
stacklevel=2,
)
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(
self,
*,
intervention_handlers: List[InterventionHandler] | None = None,
tracer_provider: TracerProvider | None = None,
) -> None:
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._intervention_handlers = intervention_handlers
self._outstanding_tasks = Counter()
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
self._serialization_registry = SerializationRegistry()
@property
def unprocessed_messages(
self,
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
@property
def outstanding_tasks(self) -> int:
return self._outstanding_tasks.get()
@property
def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())
# Returns the response of the message
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
with self._tracer_helper.trace_block(
"create",
recipient,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
future = asyncio.get_event_loop().create_future()
if recipient.type not in self._known_agent_names:
future.set_exception(Exception("Recipient not found"))
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
metadata=get_telemetry_envelope_metadata(),
)
)
cancellation_token.link_future(future)
return await future
async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> None:
with self._tracer_helper.trace_block(
"create",
topic_id,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
if cancellation_token is None:
cancellation_token = CancellationToken()
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}")
if message_id is None:
message_id = str(uuid.uuid4())
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=None,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND,
# )
# )
self._message_queue.append(
PublishMessageEnvelope(
message=message,
cancellation_token=cancellation_token,
sender=sender,
topic_id=topic_id,
metadata=get_telemetry_envelope_metadata(),
message_id=message_id,
)
)
async def save_state(self) -> Mapping[str, Any]:
state: Dict[str, Dict[str, Any]] = {}
for agent_id in self._instantiated_agents:
state[str(agent_id)] = dict(await (await self._get_agent(agent_id)).save_state())
return state
async def load_state(self, state: Mapping[str, Any]) -> None:
for agent_id_str in state:
agent_id = AgentId.from_str(agent_id_str)
if agent_id.type in self._known_agent_names:
await (await self._get_agent(agent_id)).load_state(state[str(agent_id)])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
recipient = message_envelope.recipient
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents
try:
# TODO use id
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
recipient_agent = await self._get_agent(recipient)
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
# Will be fixed when send API removed
message_id="NOT_DEFINED_TODO_FIX",
)
with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message(
message_envelope.message,
ctx=message_context,
)
except CancelledError as e:
if not message_envelope.future.cancelled():
message_envelope.future.set_exception(e)
self._outstanding_tasks.decrement()
return
except BaseException as e:
message_envelope.future.set_exception(e)
self._outstanding_tasks.decrement()
return
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
metadata=get_telemetry_envelope_metadata(),
)
)
self._outstanding_tasks.decrement()
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
try:
responses: List[Awaitable[Any]] = []
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
for agent_id in recipients:
# Avoid sending the message back to the sender
if message_envelope.sender is not None and agent_id == message_envelope.sender:
continue
sender_agent = (
await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
)
sender_name = str(sender_agent.id) if sender_agent is not None else "Unknown"
logger.info(
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=agent,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=message_envelope.topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
agent = await self._get_agent(agent_id)
async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
with self._tracer_helper.trace_block("process", agent.id, parent=None):
with MessageHandlerContext.populate_context(agent.id):
return await agent.on_message(
message_envelope.message,
ctx=message_context,
)
future = _on_message(agent, message_context)
responses.append(future)
await asyncio.gather(*responses)
except BaseException as e:
# Ignore cancelled errors from logs
if isinstance(e, CancelledError):
return
logger.error("Error processing publish message", exc_info=True)
finally:
self._outstanding_tasks.decrement()
# TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata):
content = (
message_envelope.message.__dict__
if hasattr(message_envelope.message, "__dict__")
else message_envelope.message
)
logger.info(
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
self._outstanding_tasks.decrement()
if not message_envelope.future.cancelled():
message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:
"""Process the next message in the queue."""
if len(self._message_queue) == 0:
# Yield control to the event loop to allow other tasks to run
await asyncio.sleep(0)
return
message_envelope = self._message_queue.pop(0)
match message_envelope:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
task = asyncio.create_task(self._process_send(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case PublishMessageEnvelope(
message=message,
sender=sender,
):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_publish(message, sender=sender)
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
# TODO log message dropped
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
try:
temp_message = await handler.on_response(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_response")
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
task = asyncio.create_task(self._process_response(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
@property
def idle(self) -> bool:
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
def start(self) -> None:
"""Start the runtime message processing loop."""
if self._run_context is not None:
raise RuntimeError("Runtime is already started")
self._run_context = RunContext(self)
async def stop(self) -> None:
"""Stop the runtime message processing loop."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop()
self._run_context = None
async def stop_when_idle(self) -> None:
"""Stop the runtime message processing loop when there is
no outstanding message being processed or queued."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when_idle()
self._run_context = None
async def stop_when(self, condition: Callable[[], bool]) -> None:
"""Stop the runtime message processing loop when the condition is met."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when(condition)
self._run_context = None
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return (await self._get_agent(agent)).metadata
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
return await (await self._get_agent(agent)).save_state()
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
await (await self._get_agent(agent)).load_state(state)
@deprecated(
"Use your agent's `register` method directly instead of this method. See documentation for latest usage."
)
async def register(
self,
type: str,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
| list[Subscription]
| None = None,
) -> AgentType:
if type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
if subscriptions is not None:
if callable(subscriptions):
with SubscriptionInstantiationContext.populate_context(AgentType(type)):
subscriptions_list_result = subscriptions()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
else:
subscriptions_list = subscriptions
for subscription in subscriptions_list:
await self.add_subscription(subscription)
self._agent_factories[type] = agent_factory
return AgentType(type)
async def register_factory(
self,
*,
type: AgentType,
agent_factory: Callable[[], T | Awaitable[T]],
expected_class: type[T],
) -> AgentType:
if type.type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")
async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
agent_instance = await maybe_agent_instance
else:
agent_instance = maybe_agent_instance
if type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance
self._agent_factories[type.type] = factory_wrapper
return type
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
agent_id: AgentId,
) -> T:
with AgentInstantiationContext.populate_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
elif len(inspect.signature(agent_factory).parameters) == 2:
warnings.warn(
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
stacklevel=2,
)
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
return agent
async def _get_agent(self, agent_id: AgentId) -> Agent:
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
if agent_id.type not in self._agent_factories:
raise LookupError(f"Agent with name {agent_id.type} not found.")
agent_factory = self._agent_factories[agent_id.type]
agent = await self._invoke_agent_factory(agent_factory, agent_id)
self._instantiated_agents[agent_id] = agent
return agent
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
if id.type not in self._agent_factories:
raise LookupError(f"Agent with name {id.type} not found.")
# TODO: check if remote
agent_instance = await self._get_agent(id)
if not isinstance(agent_instance, type):
raise TypeError(
f"Agent with name {id.type} is not of type {type.__name__}. It is of type {type_func_alias(agent_instance).__name__}"
)
return agent_instance
async def add_subscription(self, subscription: Subscription) -> None:
await self._subscription_manager.add_subscription(subscription)
async def remove_subscription(self, id: str) -> None:
await self._subscription_manager.remove_subscription(id)
async def get(
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
) -> AgentId:
return await get_impl(
id_or_type=id_or_type,
key=key,
lazy=lazy,
instance_getter=self._get_agent,
)
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer)

View File

@@ -6,7 +6,8 @@ from opentelemetry.trace import SpanKind
from opentelemetry.util import types
from typing_extensions import NotRequired
from ... import AgentId, TopicId
from .._agent_id import AgentId
from .._topic import TopicId
from ._constants import NAMESPACE
logger = logging.getLogger("autogen_core")

View File

@@ -1,9 +1,6 @@
from collections.abc import Sequence
from types import NoneType, UnionType
from typing import Any, Optional, Tuple, Type, Union, get_args, get_origin
# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors
ChannelArgumentType = Sequence[Tuple[str, Any]]
from typing import Any, Optional, Type, Union, get_args, get_origin
def is_union(t: object) -> bool:

View File

@@ -1,9 +0,0 @@
"""
The :mod:`autogen_core.application` module provides implementations of core components that are used to compose an application
"""
from ._single_threaded_agent_runtime import SingleThreadedAgentRuntime
from ._worker_runtime import WorkerAgentRuntime
from ._worker_runtime_host import WorkerAgentRuntimeHost
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "WorkerAgentRuntimeHost"]

View File

@@ -1,689 +1,10 @@
from __future__ import annotations
import asyncio
import inspect
import logging
import threading
import uuid
import warnings
from asyncio import CancelledError, Future, Task
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
from opentelemetry.trace import TracerProvider
from typing_extensions import deprecated
from autogen_core._serialization import MessageSerializer, SerializationRegistry
from .._single_threaded_agent_runtime import SingleThreadedAgentRuntime as SingleThreadedAgentRuntimeAlias
from .. import (
Agent,
AgentId,
AgentInstantiationContext,
AgentMetadata,
AgentRuntime,
AgentType,
CancellationToken,
MessageContext,
MessageHandlerContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
@deprecated(
"autogen_core.application.SingleThreadedAgentRuntime moved to autogen_core.SingleThreadedAgentRuntime. This alias will be removed in 0.4.0."
)
from ..base.intervention import DropMessage, InterventionHandler
from ..exceptions import MessageDroppedException
from ._helpers import SubscriptionManager, get_impl
from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
# We use a type parameter in some functions which shadows the built-in `type` function.
# This is a workaround to avoid shadowing the built-in `type` function.
type_func_alias = type
@dataclass(kw_only=True)
class PublishMessageEnvelope:
"""A message envelope for publishing messages to all agents that can handle
the message of the type T."""
message: Any
cancellation_token: CancellationToken
sender: AgentId | None
topic_id: TopicId
metadata: EnvelopeMetadata | None = None
message_id: str
@dataclass(kw_only=True)
class SendMessageEnvelope:
"""A message envelope for sending a message to a specific agent that can handle
the message of the type T."""
message: Any
sender: AgentId | None
recipient: AgentId
future: Future[Any]
cancellation_token: CancellationToken
metadata: EnvelopeMetadata | None = None
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
"""A message envelope for sending a response to a message."""
message: Any
future: Future[Any]
sender: AgentId
recipient: AgentId | None
metadata: EnvelopeMetadata | None = None
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
class Counter:
def __init__(self) -> None:
self._count: int = 0
self.threadLock = threading.Lock()
def increment(self) -> None:
self.threadLock.acquire()
self._count += 1
self.threadLock.release()
def get(self) -> int:
return self._count
def decrement(self) -> None:
self.threadLock.acquire()
self._count -= 1
self.threadLock.release()
class RunContext:
class RunState(Enum):
RUNNING = 0
CANCELLED = 1
UNTIL_IDLE = 2
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
self._run_state = RunContext.RunState.RUNNING
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
self._run_task = asyncio.create_task(self._run())
self._lock = asyncio.Lock()
async def _run(self) -> None:
while True:
async with self._lock:
if self._end_condition():
return
await self._runtime.process_next()
async def stop(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.CANCELLED
self._end_condition = self._stop_when_cancelled
await self._run_task
async def stop_when_idle(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.UNTIL_IDLE
self._end_condition = self._stop_when_idle
await self._run_task
async def stop_when(self, condition: Callable[[], bool]) -> None:
async with self._lock:
self._end_condition = condition
await self._run_task
def _stop_when_cancelled(self) -> bool:
return self._run_state == RunContext.RunState.CANCELLED
def _stop_when_idle(self) -> bool:
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
def _warn_if_none(value: Any, handler_name: str) -> None:
"""
Utility function to check if the intervention handler returned None and issue a warning.
Args:
value: The return value to check
handler_name: Name of the intervention handler method for the warning message
"""
if value is None:
warnings.warn(
f"Intervention handler {handler_name} returned None. This might be unintentional. "
"Consider returning the original message or DropMessage explicitly.",
RuntimeWarning,
stacklevel=2,
)
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(
self,
*,
intervention_handlers: List[InterventionHandler] | None = None,
tracer_provider: TracerProvider | None = None,
) -> None:
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
] = {}
self._instantiated_agents: Dict[AgentId, Agent] = {}
self._intervention_handlers = intervention_handlers
self._outstanding_tasks = Counter()
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
self._serialization_registry = SerializationRegistry()
@property
def unprocessed_messages(
self,
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
return self._message_queue
@property
def outstanding_tasks(self) -> int:
return self._outstanding_tasks.get()
@property
def _known_agent_names(self) -> Set[str]:
return set(self._agent_factories.keys())
# Returns the response of the message
async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
with self._tracer_helper.trace_block(
"create",
recipient,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
future = asyncio.get_event_loop().create_future()
if recipient.type not in self._known_agent_names:
future.set_exception(Exception("Recipient not found"))
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
metadata=get_telemetry_envelope_metadata(),
)
)
cancellation_token.link_future(future)
return await future
async def publish_message(
    self,
    message: Any,
    topic_id: TopicId,
    *,
    sender: AgentId | None = None,
    cancellation_token: CancellationToken | None = None,
    message_id: str | None = None,
) -> None:
    """Enqueue *message* for delivery to every agent subscribed to *topic_id*.

    Fire-and-forget: returns as soon as the ``PublishMessageEnvelope`` is
    queued; fan-out happens later in the processing loop and handler return
    values are discarded.

    Args:
        message: Arbitrary payload delivered to each subscriber's handler.
        topic_id: Topic used by the subscription manager to pick recipients.
        sender: Optional id of the publishing agent (excluded from delivery).
        cancellation_token: Token attached to the envelope; created if omitted.
        message_id: Stable id for the publish; a fresh UUID4 when omitted.
    """
    with self._tracer_helper.trace_block(
        "create",
        topic_id,
        parent=None,
        extraAttributes={"message_type": type(message).__name__},
    ):
        if cancellation_token is None:
            cancellation_token = CancellationToken()
        # Best-effort payload preview for logging only.
        content = message.__dict__ if hasattr(message, "__dict__") else message
        logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}")
        if message_id is None:
            message_id = str(uuid.uuid4())
        # event_logger.info(
        #     MessageEvent(
        #         payload=message,
        #         sender=sender,
        #         receiver=None,
        #         kind=MessageKind.PUBLISH,
        #         delivery_stage=DeliveryStage.SEND,
        #     )
        # )
        self._message_queue.append(
            PublishMessageEnvelope(
                message=message,
                cancellation_token=cancellation_token,
                sender=sender,
                topic_id=topic_id,
                metadata=get_telemetry_envelope_metadata(),
                message_id=message_id,
            )
        )
async def save_state(self) -> Mapping[str, Any]:
    """Snapshot the state of every instantiated agent, keyed by stringified AgentId."""
    snapshot: Dict[str, Dict[str, Any]] = {}
    for agent_id in self._instantiated_agents:
        agent = await self._get_agent(agent_id)
        snapshot[str(agent_id)] = dict(await agent.save_state())
    return snapshot
async def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore agent states from a mapping produced by save_state().

    Entries whose agent type has no registered factory are skipped silently.
    """
    for agent_id_str in state:
        agent_id = AgentId.from_str(agent_id_str)
        if agent_id.type not in self._known_agent_names:
            continue
        agent = await self._get_agent(agent_id)
        await agent.load_state(state[str(agent_id)])
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
    """Deliver one direct (RPC) envelope to its recipient.

    Runs as a background task created by process_next(). On success the
    handler's return value is re-queued as a ResponseMessageEnvelope so the
    runtime loop controls delivery order; on failure the envelope's future is
    resolved with the exception instead. The outstanding-task counter is
    decremented on every exit path.
    """
    with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
        recipient = message_envelope.recipient
        # todo: check if recipient is in the known namespaces
        # assert recipient in self._agents
        try:
            # TODO use id
            sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
            logger.info(
                f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
            )
            # event_logger.info(
            #     MessageEvent(
            #         payload=message_envelope.message,
            #         sender=message_envelope.sender,
            #         receiver=recipient,
            #         kind=MessageKind.DIRECT,
            #         delivery_stage=DeliveryStage.DELIVER,
            #     )
            # )
            # Instantiates the recipient on first use (factory invoked lazily).
            recipient_agent = await self._get_agent(recipient)
            message_context = MessageContext(
                sender=message_envelope.sender,
                topic_id=None,
                is_rpc=True,
                cancellation_token=message_envelope.cancellation_token,
                # Will be fixed when send API removed
                message_id="NOT_DEFINED_TODO_FIX",
            )
            with MessageHandlerContext.populate_context(recipient_agent.id):
                response = await recipient_agent.on_message(
                    message_envelope.message,
                    ctx=message_context,
                )
        except CancelledError as e:
            # Propagate cancellation to the caller unless it already cancelled
            # the future itself (setting an exception then would raise).
            if not message_envelope.future.cancelled():
                message_envelope.future.set_exception(e)
            self._outstanding_tasks.decrement()
            return
        except BaseException as e:
            message_envelope.future.set_exception(e)
            self._outstanding_tasks.decrement()
            return
        # Success: hand the response back through the queue for ordered delivery.
        self._message_queue.append(
            ResponseMessageEnvelope(
                message=response,
                future=message_envelope.future,
                sender=message_envelope.recipient,
                recipient=message_envelope.sender,
                metadata=get_telemetry_envelope_metadata(),
            )
        )
        self._outstanding_tasks.decrement()
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
    """Fan a published envelope out to every subscribed recipient.

    Runs as a background task created by process_next(). Handlers for all
    recipients run concurrently via asyncio.gather and their return values
    are discarded. Non-cancellation errors are logged and swallowed; the
    outstanding-task counter is always decremented via the finally block.
    """
    with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
        try:
            responses: List[Awaitable[Any]] = []
            recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
            for agent_id in recipients:
                # Avoid sending the message back to the sender
                if message_envelope.sender is not None and agent_id == message_envelope.sender:
                    continue
                # NOTE(review): this sender lookup is loop-invariant but is kept
                # inside the loop so the sender is only instantiated when there
                # is at least one recipient — confirm before hoisting.
                sender_agent = (
                    await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None
                )
                sender_name = str(sender_agent.id) if sender_agent is not None else "Unknown"
                logger.info(
                    f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
                )
                # event_logger.info(
                #     MessageEvent(
                #         payload=message_envelope.message,
                #         sender=message_envelope.sender,
                #         receiver=agent,
                #         kind=MessageKind.PUBLISH,
                #         delivery_stage=DeliveryStage.DELIVER,
                #     )
                # )
                message_context = MessageContext(
                    sender=message_envelope.sender,
                    topic_id=message_envelope.topic_id,
                    is_rpc=False,
                    cancellation_token=message_envelope.cancellation_token,
                    message_id=message_envelope.message_id,
                )
                agent = await self._get_agent(agent_id)

                # agent/message_context are passed as arguments (not closed
                # over), so each coroutine is bound to this iteration's
                # recipient — avoids the late-binding closure pitfall.
                async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
                    with self._tracer_helper.trace_block("process", agent.id, parent=None):
                        with MessageHandlerContext.populate_context(agent.id):
                            return await agent.on_message(
                                message_envelope.message,
                                ctx=message_context,
                            )

                # Despite the name, this is a coroutine object; it is awaited
                # by the gather below.
                future = _on_message(agent, message_context)
                responses.append(future)

            await asyncio.gather(*responses)
        except BaseException as e:
            # Ignore cancelled errors from logs
            if isinstance(e, CancelledError):
                return
            logger.error("Error processing publish message", exc_info=True)
        finally:
            self._outstanding_tasks.decrement()
    # TODO if responses are given for a publish
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
    """Resolve the future of a direct send with the recipient's reply.

    Runs as a background task created by process_next(); this is the final
    stage of the send/response round trip started in _process_send().
    """
    with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata):
        # Best-effort payload preview for logging only.
        content = (
            message_envelope.message.__dict__
            if hasattr(message_envelope.message, "__dict__")
            else message_envelope.message
        )
        logger.info(
            f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
        )
        # event_logger.info(
        #     MessageEvent(
        #         payload=message_envelope.message,
        #         sender=message_envelope.sender,
        #         receiver=message_envelope.recipient,
        #         kind=MessageKind.RESPOND,
        #         delivery_stage=DeliveryStage.DELIVER,
        #     )
        # )
        self._outstanding_tasks.decrement()
        # The awaiting caller may have cancelled meanwhile; setting a result
        # on a cancelled future would raise InvalidStateError.
        if not message_envelope.future.cancelled():
            message_envelope.future.set_result(message_envelope.message)
async def process_next(self) -> None:
"""Process the next message in the queue."""
if len(self._message_queue) == 0:
# Yield control to the event loop to allow other tasks to run
await asyncio.sleep(0)
return
message_envelope = self._message_queue.pop(0)
match message_envelope:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
task = asyncio.create_task(self._process_send(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case PublishMessageEnvelope(
message=message,
sender=sender,
):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_publish(message, sender=sender)
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
# TODO log message dropped
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
try:
temp_message = await handler.on_response(message, sender=sender, recipient=recipient)
_warn_if_none(temp_message, "on_response")
except BaseException as e:
# TODO: should we raise the exception to sender of the response instead?
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
self._outstanding_tasks.increment()
task = asyncio.create_task(self._process_response(message_envelope))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
@property
def idle(self) -> bool:
    """True when no envelopes are queued and no handler tasks are in flight."""
    queue_empty = len(self._message_queue) == 0
    nothing_running = self._outstanding_tasks.get() == 0
    return queue_empty and nothing_running
def start(self) -> None:
    """Start the runtime message processing loop.

    Raises:
        RuntimeError: If the runtime is already started.
    """
    if self._run_context is not None:
        raise RuntimeError("Runtime is already started")
    # The RunContext drives process_next() until one of the stop* methods runs.
    self._run_context = RunContext(self)
async def stop(self) -> None:
    """Stop the runtime message processing loop.

    Raises:
        RuntimeError: If the runtime is not started.
    """
    if self._run_context is None:
        raise RuntimeError("Runtime is not started")
    await self._run_context.stop()
    # Clearing the context allows the runtime to be started again.
    self._run_context = None
async def stop_when_idle(self) -> None:
    """Stop the runtime message processing loop when there is
    no outstanding message being processed or queued.

    Raises:
        RuntimeError: If the runtime is not started.
    """
    if self._run_context is None:
        raise RuntimeError("Runtime is not started")
    await self._run_context.stop_when_idle()
    # Clearing the context allows the runtime to be started again.
    self._run_context = None
async def stop_when(self, condition: Callable[[], bool]) -> None:
    """Stop the runtime message processing loop when the condition is met.

    Args:
        condition: Zero-argument predicate checked by the run context; the
            loop stops once it returns True.

    Raises:
        RuntimeError: If the runtime is not started.
    """
    if self._run_context is None:
        raise RuntimeError("Runtime is not started")
    await self._run_context.stop_when(condition)
    # Clearing the context allows the runtime to be started again.
    self._run_context = None
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
    """Return the metadata of *agent*, instantiating it on first use."""
    instance = await self._get_agent(agent)
    return instance.metadata
async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
    """Return the saved state of *agent*, instantiating it on first use."""
    instance = await self._get_agent(agent)
    return await instance.save_state()
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
    """Load *state* into *agent*, instantiating it on first use."""
    instance = await self._get_agent(agent)
    await instance.load_state(state)
@deprecated(
    "Use your agent's `register` method directly instead of this method. See documentation for latest usage."
)
async def register(
    self,
    type: str,
    agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
    subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]]
    | list[Subscription]
    | None = None,
) -> AgentType:
    """Deprecated. Register an agent factory (and optional subscriptions) under *type*.

    Args:
        type: Unique agent type name; a duplicate registration raises ValueError.
        agent_factory: Zero-argument or (runtime, agent_id) factory that may
            return the agent directly or an awaitable of it.
        subscriptions: Subscriptions to add, given either as a list or as a
            (possibly async) callable evaluated inside a
            SubscriptionInstantiationContext for this agent type.

    Returns:
        The AgentType wrapper for *type*.
    """
    if type in self._agent_factories:
        raise ValueError(f"Agent with type {type} already exists.")
    if subscriptions is not None:
        if callable(subscriptions):
            with SubscriptionInstantiationContext.populate_context(AgentType(type)):
                subscriptions_list_result = subscriptions()
                # Await inside the context so async subscription factories can
                # still read it.
                if inspect.isawaitable(subscriptions_list_result):
                    subscriptions_list = await subscriptions_list_result
                else:
                    subscriptions_list = subscriptions_list_result
        else:
            subscriptions_list = subscriptions
        for subscription in subscriptions_list:
            await self.add_subscription(subscription)
    self._agent_factories[type] = agent_factory
    return AgentType(type)
async def register_factory(
    self,
    *,
    type: AgentType,
    agent_factory: Callable[[], T | Awaitable[T]],
    expected_class: type[T],
) -> AgentType:
    """Register a zero-argument agent factory under *type*.

    The stored factory is wrapped so that, when an agent is instantiated, an
    async factory result is awaited and the produced agent's type is checked
    against *expected_class*.

    Raises:
        ValueError: If a factory is already registered for *type*.
    """
    if type.type in self._agent_factories:
        raise ValueError(f"Agent with type {type} already exists.")

    async def checked_factory() -> T:
        produced = agent_factory()
        if inspect.isawaitable(produced):
            instance = await produced
        else:
            instance = produced
        # The factory must build exactly the class it was registered for.
        if type_func_alias(instance) != expected_class:
            raise ValueError("Factory registered using the wrong type.")
        return instance

    self._agent_factories[type.type] = checked_factory
    return type
async def _invoke_agent_factory(
    self,
    agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
    agent_id: AgentId,
) -> T:
    """Call *agent_factory* inside an AgentInstantiationContext and return the agent.

    Supports both factory signatures: zero-argument (preferred; reads the
    populated context) and the deprecated two-argument (runtime, agent_id)
    form. An awaitable factory result is awaited while the context is still
    populated.

    Raises:
        ValueError: If the factory takes neither 0 nor 2 parameters.
    """
    with AgentInstantiationContext.populate_context((self, agent_id)):
        # Dispatch on the factory's arity.
        if len(inspect.signature(agent_factory).parameters) == 0:
            factory_one = cast(Callable[[], T], agent_factory)
            agent = factory_one()
        elif len(inspect.signature(agent_factory).parameters) == 2:
            warnings.warn(
                "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
                stacklevel=2,
            )
            factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
            agent = factory_two(self, agent_id)
        else:
            raise ValueError("Agent factory must take 0 or 2 arguments.")
        if inspect.isawaitable(agent):
            return cast(T, await agent)
    return agent
async def _get_agent(self, agent_id: AgentId) -> Agent:
    """Return the agent instance for *agent_id*, instantiating and caching it on first use.

    Raises:
        LookupError: If no factory is registered for ``agent_id.type``.
    """
    cached = self._instantiated_agents.get(agent_id)
    if cached is not None:
        return cached
    if agent_id.type not in self._agent_factories:
        raise LookupError(f"Agent with name {agent_id.type} not found.")
    factory = self._agent_factories[agent_id.type]
    instance = await self._invoke_agent_factory(factory, agent_id)
    self._instantiated_agents[agent_id] = instance
    return instance
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T:  # type: ignore[assignment]
    """Return the concrete agent instance for *id*, checked against *type*.

    Instantiates the agent on first use.

    Raises:
        LookupError: If no factory is registered for the agent's type.
        TypeError: If the instantiated agent is not an instance of *type*.
    """
    if id.type not in self._agent_factories:
        raise LookupError(f"Agent with name {id.type} not found.")
    # TODO: check if remote
    agent_instance = await self._get_agent(id)
    if not isinstance(agent_instance, type):
        raise TypeError(
            f"Agent with name {id.type} is not of type {type.__name__}. It is of type {type_func_alias(agent_instance).__name__}"
        )
    return agent_instance
async def add_subscription(self, subscription: Subscription) -> None:
    """Register *subscription* so published topics can route to its agent type (delegates to the subscription manager)."""
    await self._subscription_manager.add_subscription(subscription)
async def remove_subscription(self, id: str) -> None:
    """Remove the subscription with the given id (handling of unknown ids is delegated to the subscription manager)."""
    await self._subscription_manager.remove_subscription(id)
async def get(
    self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
) -> AgentId:
    """Resolve an AgentId from an existing id, an AgentType, or a type-name string.

    Args:
        id_or_type: Existing AgentId, or an AgentType/str naming the agent type.
        key: Instance key used when only a type is given.
        lazy: Passed through to get_impl; when False the agent is instantiated
            eagerly via self._get_agent.
    """
    return await get_impl(
        id_or_type=id_or_type,
        key=key,
        lazy=lazy,
        instance_getter=self._get_agent,
    )
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
    """Register one serializer — or a sequence of them — with the runtime's serialization registry."""
    self._serialization_registry.add_serializer(serializer)
class SingleThreadedAgentRuntime(SingleThreadedAgentRuntimeAlias):
    """Backward-compatibility alias for the relocated SingleThreadedAgentRuntime.

    NOTE(review): presumably kept so old import paths keep working after the
    implementation moved (imports elsewhere in this change use
    ``from autogen_core import SingleThreadedAgentRuntime``) — confirm the
    intended deprecation path.
    """

    pass

View File

@@ -1,15 +1,9 @@
ROOT_LOGGER_NAME = "autogen_core"
"""str: Logger name used for structured event logging"""
"""Deprecated alias. Use autogen_core.ROOT_LOGGER_NAME"""
EVENT_LOGGER_NAME = "autogen_core.events"
"""str: Logger name used for structured event logging"""
"""Deprecated alias. Use autogen_core.EVENT_LOGGER_NAME"""
TRACE_LOGGER_NAME = "autogen_core.trace"
"""str: Logger name used for developer intended trace logging. The content and format of this log should not be depended upon."""
__all__ = [
"ROOT_LOGGER_NAME",
"EVENT_LOGGER_NAME",
"TRACE_LOGGER_NAME",
]
"""Deprecated alias. Use autogen_core.TRACE_LOGGER_NAME"""

View File

@@ -1,8 +0,0 @@
"""
The :mod:`autogen_core.worker.protos` module provides Google Protobuf classes for agent-worker communication
"""
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

View File

@@ -1,75 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: agent_worker.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
import cloudevent_pb2 as cloudevent__pb2
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 
\x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 
\x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'agent_worker_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\036Microsoft.AutoGen.Abstractions'
_globals['_RPCREQUEST_METADATAENTRY']._options = None
_globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001'
_globals['_RPCRESPONSE_METADATAENTRY']._options = None
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001'
_globals['_EVENT_METADATAENTRY']._options = None
_globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001'
_globals['_TOPICID']._serialized_start=75
_globals['_TOPICID']._serialized_end=114
_globals['_AGENTID']._serialized_start=116
_globals['_AGENTID']._serialized_end=152
_globals['_PAYLOAD']._serialized_start=154
_globals['_PAYLOAD']._serialized_end=223
_globals['_RPCREQUEST']._serialized_start=226
_globals['_RPCREQUEST']._serialized_end=491
_globals['_RPCREQUEST_METADATAENTRY']._serialized_start=433
_globals['_RPCREQUEST_METADATAENTRY']._serialized_end=480
_globals['_RPCRESPONSE']._serialized_start=494
_globals['_RPCRESPONSE']._serialized_end=678
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=433
_globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=480
_globals['_EVENT']._serialized_start=681
_globals['_EVENT']._serialized_end=909
_globals['_EVENT_METADATAENTRY']._serialized_start=433
_globals['_EVENT_METADATAENTRY']._serialized_end=480
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911
_globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973
_globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067
_globals['_TYPESUBSCRIPTION']._serialized_start=1069
_globals['_TYPESUBSCRIPTION']._serialized_end=1127
_globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=1129
_globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=1200
_globals['_SUBSCRIPTION']._serialized_start=1203
_globals['_SUBSCRIPTION']._serialized_end=1353
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1355
_globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1443
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1445
_globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1537
_globals['_AGENTSTATE']._serialized_start=1540
_globals['_AGENTSTATE']._serialized_end=1697
_globals['_GETSTATERESPONSE']._serialized_start=1699
_globals['_GETSTATERESPONSE']._serialized_end=1805
_globals['_SAVESTATERESPONSE']._serialized_start=1807
_globals['_SAVESTATERESPONSE']._serialized_end=1873
_globals['_MESSAGE']._serialized_start=1876
_globals['_MESSAGE']._serialized_end=2298
_globals['_AGENTRPC']._serialized_start=2301
_globals['_AGENTRPC']._serialized_end=2479
# @@protoc_insertion_point(module_scope)

View File

@@ -1,132 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import agent_worker_pb2 as agent__worker__pb2
class AgentRpcStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.OpenChannel = channel.stream_stream(
'/agents.AgentRpc/OpenChannel',
request_serializer=agent__worker__pb2.Message.SerializeToString,
response_deserializer=agent__worker__pb2.Message.FromString,
)
self.GetState = channel.unary_unary(
'/agents.AgentRpc/GetState',
request_serializer=agent__worker__pb2.AgentId.SerializeToString,
response_deserializer=agent__worker__pb2.GetStateResponse.FromString,
)
self.SaveState = channel.unary_unary(
'/agents.AgentRpc/SaveState',
request_serializer=agent__worker__pb2.AgentState.SerializeToString,
response_deserializer=agent__worker__pb2.SaveStateResponse.FromString,
)
class AgentRpcServicer(object):
"""Missing associated documentation comment in .proto file."""
def OpenChannel(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetState(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SaveState(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AgentRpcServicer_to_server(servicer, server):
rpc_method_handlers = {
'OpenChannel': grpc.stream_stream_rpc_method_handler(
servicer.OpenChannel,
request_deserializer=agent__worker__pb2.Message.FromString,
response_serializer=agent__worker__pb2.Message.SerializeToString,
),
'GetState': grpc.unary_unary_rpc_method_handler(
servicer.GetState,
request_deserializer=agent__worker__pb2.AgentId.FromString,
response_serializer=agent__worker__pb2.GetStateResponse.SerializeToString,
),
'SaveState': grpc.unary_unary_rpc_method_handler(
servicer.SaveState,
request_deserializer=agent__worker__pb2.AgentState.FromString,
response_serializer=agent__worker__pb2.SaveStateResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'agents.AgentRpc', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class AgentRpc(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def OpenChannel(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/agents.AgentRpc/OpenChannel',
agent__worker__pb2.Message.SerializeToString,
agent__worker__pb2.Message.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GetState(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/GetState',
agent__worker__pb2.AgentId.SerializeToString,
agent__worker__pb2.GetStateResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def SaveState(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/SaveState',
agent__worker__pb2.AgentState.SerializeToString,
agent__worker__pb2.SaveStateResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@@ -1,39 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: cloudevent.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63loudevent.proto\x12\ncloudevent\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xa4\x05\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12:\n\nattributes\x18\x05 \x03(\x0b\x32&.cloudevent.CloudEvent.AttributesEntry\x12\x36\n\x08metadata\x18\x06 \x03(\x0b\x32$.cloudevent.CloudEvent.MetadataEntry\x12\x17\n\x0f\x64\x61tacontenttype\x18\x07 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x08 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\t \x01(\tH\x00\x12*\n\nproto_data\x18\n \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1a\x62\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12>\n\x05value\x18\x02 \x01(\x0b\x32/.cloudevent.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'cloudevent_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\252\002\036Microsoft.AutoGen.Abstractions'
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._options = None
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_options = b'8\001'
_globals['_CLOUDEVENT_METADATAENTRY']._options = None
_globals['_CLOUDEVENT_METADATAENTRY']._serialized_options = b'8\001'
_globals['_CLOUDEVENT']._serialized_start=93
_globals['_CLOUDEVENT']._serialized_end=769
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_start=400
_globals['_CLOUDEVENT_ATTRIBUTESENTRY']._serialized_end=498
_globals['_CLOUDEVENT_METADATAENTRY']._serialized_start=500
_globals['_CLOUDEVENT_METADATAENTRY']._serialized_end=547
_globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_start=550
_globals['_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE']._serialized_end=761
# @@protoc_insertion_point(module_scope)

View File

@@ -1,7 +1,7 @@
import pytest
from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime
from autogen_test_utils import NoopAgent
from pytest_mock import MockerFixture
from test_utils import NoopAgent
@pytest.mark.asyncio

View File

@@ -8,9 +8,9 @@ from autogen_core import (
CancellationToken,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
message_handler,
)
from autogen_core.application import SingleThreadedAgentRuntime
@dataclass

View File

@@ -2,8 +2,14 @@ import asyncio
from dataclasses import dataclass
import pytest
from autogen_core import ClosureAgent, ClosureContext, DefaultSubscription, DefaultTopicId, MessageContext
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import (
ClosureAgent,
ClosureContext,
DefaultSubscription,
DefaultTopicId,
MessageContext,
SingleThreadedAgentRuntime,
)
@dataclass

View File

@@ -1,9 +1,8 @@
import pytest
from autogen_core import AgentId
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import AgentId, SingleThreadedAgentRuntime
from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage
from autogen_core.exceptions import MessageDroppedException
from test_utils import LoopbackAgent, MessageType
from autogen_test_utils import LoopbackAgent, MessageType
@pytest.mark.asyncio

View File

@@ -3,9 +3,18 @@ from dataclasses import dataclass
from typing import Callable, cast
import pytest
from autogen_core import AgentId, MessageContext, RoutedAgent, TopicId, TypeSubscription, event, message_handler, rpc
from autogen_core.application import SingleThreadedAgentRuntime
from test_utils import LoopbackAgent
from autogen_core import (
AgentId,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
TopicId,
TypeSubscription,
event,
message_handler,
rpc,
)
from autogen_test_utils import LoopbackAgent
@dataclass

View File

@@ -6,14 +6,13 @@ from autogen_core import (
AgentInstantiationContext,
AgentType,
DefaultTopicId,
SingleThreadedAgentRuntime,
TopicId,
TypeSubscription,
try_get_known_serializers_for_type,
type_subscription,
)
from autogen_core.application import SingleThreadedAgentRuntime
from opentelemetry.sdk.trace import TracerProvider
from test_utils import (
from autogen_test_utils import (
CascadingAgent,
CascadingMessageType,
LoopbackAgent,
@@ -21,7 +20,8 @@ from test_utils import (
MessageType,
NoopAgent,
)
from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
from autogen_test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
from opentelemetry.sdk.trace import TracerProvider
test_exporter = TestExporter()

View File

@@ -1,8 +1,7 @@
from typing import Any, Mapping
import pytest
from autogen_core import AgentId, BaseAgent, MessageContext
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import AgentId, BaseAgent, MessageContext, SingleThreadedAgentRuntime
class StatefulAgent(BaseAgent):

View File

@@ -1,8 +1,14 @@
import pytest
from autogen_core import AgentId, DefaultSubscription, DefaultTopicId, TopicId, TypeSubscription
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import (
AgentId,
DefaultSubscription,
DefaultTopicId,
SingleThreadedAgentRuntime,
TopicId,
TypeSubscription,
)
from autogen_core.exceptions import CantHandleException
from test_utils import LoopbackAgent, MessageType
from autogen_test_utils import LoopbackAgent, MessageType
def test_type_subscription_match() -> None:

View File

@@ -3,8 +3,7 @@ import json
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
import pytest
from autogen_core import AgentId, CancellationToken, FunctionCall
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import AgentId, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,

View File

@@ -41,21 +41,28 @@ video-surfer = [
"openai-whisper",
]
grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]
[tool.uv]
dev-dependencies = []
dev-dependencies = [
"autogen_test_utils"
]
[tool.ruff]
extend = "../../pyproject.toml"
include = ["src/**", "tests/*.py"]
exclude = ["src/autogen_ext/agents/web_surfer/*.js"]
exclude = ["src/autogen_ext/agents/web_surfer/*.js", "src/autogen_ext/runtimes/grpc/protos", "tests/protos"]
[tool.pyright]
extends = "../../pyproject.toml"
include = ["src", "tests"]
exclude = ["src/autogen_ext/runtimes/grpc/protos", "tests/protos"]
[tool.pytest.ini_options]
minversion = "6.0"
@@ -66,6 +73,7 @@ include = "../../shared_tasks.toml"
[tool.poe.tasks]
test = "pytest -n auto"
mypy = "mypy --config-file ../../pyproject.toml --exclude src/autogen_ext/runtimes/grpc/protos --exclude tests/protos src tests"
[tool.mypy]
[[tool.mypy.overrides]]

View File

@@ -22,12 +22,12 @@ from typing import (
import tiktoken
from autogen_core import (
EVENT_LOGGER_NAME,
TRACE_LOGGER_NAME,
CancellationToken,
FunctionCall,
Image,
)
from autogen_core.application.logging import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from autogen_core.application.logging.events import LLMCallEvent
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
@@ -42,6 +42,7 @@ from autogen_core.components.models import (
UserMessage,
)
from autogen_core.components.tools import Tool, ToolSchema
from autogen_core.logging import LLMCallEvent
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import (
ChatCompletion,

View File

@@ -0,0 +1,16 @@
from ._worker_runtime import GrpcWorkerAgentRuntime
from ._worker_runtime_host import GrpcWorkerAgentRuntimeHost
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
try:
import grpc # type: ignore
except ImportError as e:
raise ImportError(
"To use the GRPC runtime the grpc extra must be installed. Run `pip install autogen-ext[grpc]`"
) from e
__all__ = [
"GrpcWorkerAgentRuntime",
"GrpcWorkerAgentRuntimeHost",
"GrpcWorkerAgentRuntimeHostServicer",
]

View File

@@ -0,0 +1,4 @@
from typing import Any, Sequence, Tuple
# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors
ChannelArgumentType = Sequence[Tuple[str, Any]]

View File

@@ -28,13 +28,9 @@ from typing import (
cast,
)
from google.protobuf import any_pb2
from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated
from autogen_core.application.protos import cloudevent_pb2
from .. import (
from autogen_core import (
JSON_DATA_CONTENT_TYPE,
PROTOBUF_DATA_CONTENT_TYPE,
Agent,
AgentId,
AgentInstantiationContext,
@@ -44,24 +40,26 @@ from .. import (
CancellationToken,
MessageContext,
MessageHandlerContext,
MessageSerializer,
Subscription,
SubscriptionInstantiationContext,
TopicId,
TypePrefixSubscription,
TypeSubscription,
)
from .._serialization import (
JSON_DATA_CONTENT_TYPE,
PROTOBUF_DATA_CONTENT_TYPE,
MessageSerializer,
from autogen_core._runtime_impl_helpers import SubscriptionManager, get_impl
from autogen_core._serialization import (
SerializationRegistry,
)
from .._type_helpers import ChannelArgumentType
from .._type_prefix_subscription import TypePrefixSubscription
from .._type_subscription import TypeSubscription
from autogen_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
from google.protobuf import any_pb2
from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated
from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._helpers import SubscriptionManager, get_impl
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
from ._type_helpers import ChannelArgumentType
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
try:
import grpc.aio
@@ -181,7 +179,7 @@ class HostConnection:
return await self._recv_queue.get()
class WorkerAgentRuntime(AgentRuntime):
class GrpcWorkerAgentRuntime(AgentRuntime):
def __init__(
self,
host_address: str,

View File

@@ -3,9 +3,9 @@ import logging
import signal
from typing import Optional, Sequence
from .._type_helpers import ChannelArgumentType
from ._constants import GRPC_IMPORT_ERROR_STR
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
from ._type_helpers import ChannelArgumentType
from ._worker_runtime_host_servicer import GrpcWorkerAgentRuntimeHostServicer
try:
import grpc
@@ -16,10 +16,10 @@ from .protos import agent_worker_pb2_grpc
logger = logging.getLogger("autogen_core")
class WorkerAgentRuntimeHost:
class GrpcWorkerAgentRuntimeHost:
def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None:
self._server = grpc.aio.server(options=extra_grpc_config)
self._servicer = WorkerAgentRuntimeHostServicer()
self._servicer = GrpcWorkerAgentRuntimeHostServicer()
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
self._server.add_insecure_port(address)
self._address = address

View File

@@ -4,10 +4,10 @@ from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set, cast
from .. import Subscription, TopicId, TypeSubscription
from .._type_prefix_subscription import TypePrefixSubscription
from autogen_core import Subscription, TopicId, TypePrefixSubscription, TypeSubscription
from autogen_core._runtime_impl_helpers import SubscriptionManager
from ._constants import GRPC_IMPORT_ERROR_STR
from ._helpers import SubscriptionManager
try:
import grpc
@@ -20,7 +20,7 @@ logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
class WorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer):
"""A gRPC servicer that hosts message delivery service for agents."""
def __init__(self) -> None:

View File

@@ -0,0 +1,8 @@
"""
The :mod:`autogen_ext.runtimes.grpc.protos` module provides Google Protobuf classes for agent-worker communication
"""
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: agent_worker.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
import cloudevent_pb2 as cloudevent__pb2
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 
\x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 
\x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "agent_worker_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals["DESCRIPTOR"]._options = None
_globals["DESCRIPTOR"]._serialized_options = b"\252\002\036Microsoft.AutoGen.Abstractions"
_globals["_RPCREQUEST_METADATAENTRY"]._options = None
_globals["_RPCREQUEST_METADATAENTRY"]._serialized_options = b"8\001"
_globals["_RPCRESPONSE_METADATAENTRY"]._options = None
_globals["_RPCRESPONSE_METADATAENTRY"]._serialized_options = b"8\001"
_globals["_EVENT_METADATAENTRY"]._options = None
_globals["_EVENT_METADATAENTRY"]._serialized_options = b"8\001"
_globals["_TOPICID"]._serialized_start = 75
_globals["_TOPICID"]._serialized_end = 114
_globals["_AGENTID"]._serialized_start = 116
_globals["_AGENTID"]._serialized_end = 152
_globals["_PAYLOAD"]._serialized_start = 154
_globals["_PAYLOAD"]._serialized_end = 223
_globals["_RPCREQUEST"]._serialized_start = 226
_globals["_RPCREQUEST"]._serialized_end = 491
_globals["_RPCREQUEST_METADATAENTRY"]._serialized_start = 433
_globals["_RPCREQUEST_METADATAENTRY"]._serialized_end = 480
_globals["_RPCRESPONSE"]._serialized_start = 494
_globals["_RPCRESPONSE"]._serialized_end = 678
_globals["_RPCRESPONSE_METADATAENTRY"]._serialized_start = 433
_globals["_RPCRESPONSE_METADATAENTRY"]._serialized_end = 480
_globals["_EVENT"]._serialized_start = 681
_globals["_EVENT"]._serialized_end = 909
_globals["_EVENT_METADATAENTRY"]._serialized_start = 433
_globals["_EVENT_METADATAENTRY"]._serialized_end = 480
_globals["_REGISTERAGENTTYPEREQUEST"]._serialized_start = 911
_globals["_REGISTERAGENTTYPEREQUEST"]._serialized_end = 971
_globals["_REGISTERAGENTTYPERESPONSE"]._serialized_start = 973
_globals["_REGISTERAGENTTYPERESPONSE"]._serialized_end = 1067
_globals["_TYPESUBSCRIPTION"]._serialized_start = 1069
_globals["_TYPESUBSCRIPTION"]._serialized_end = 1127
_globals["_TYPEPREFIXSUBSCRIPTION"]._serialized_start = 1129
_globals["_TYPEPREFIXSUBSCRIPTION"]._serialized_end = 1200
_globals["_SUBSCRIPTION"]._serialized_start = 1203
_globals["_SUBSCRIPTION"]._serialized_end = 1353
_globals["_ADDSUBSCRIPTIONREQUEST"]._serialized_start = 1355
_globals["_ADDSUBSCRIPTIONREQUEST"]._serialized_end = 1443
_globals["_ADDSUBSCRIPTIONRESPONSE"]._serialized_start = 1445
_globals["_ADDSUBSCRIPTIONRESPONSE"]._serialized_end = 1537
_globals["_AGENTSTATE"]._serialized_start = 1540
_globals["_AGENTSTATE"]._serialized_end = 1697
_globals["_GETSTATERESPONSE"]._serialized_start = 1699
_globals["_GETSTATERESPONSE"]._serialized_end = 1805
_globals["_SAVESTATERESPONSE"]._serialized_start = 1807
_globals["_SAVESTATERESPONSE"]._serialized_end = 1873
_globals["_MESSAGE"]._serialized_start = 1876
_globals["_MESSAGE"]._serialized_end = 2298
_globals["_AGENTRPC"]._serialized_start = 2301
_globals["_AGENTRPC"]._serialized_end = 2479
# @@protoc_insertion_point(module_scope)

View File

@@ -4,13 +4,14 @@ isort:skip_file
"""
import builtins
import cloudevent_pb2
import collections.abc
import typing
import cloudevent_pb2
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@@ -67,7 +68,12 @@ class Payload(google.protobuf.message.Message):
data_content_type: builtins.str = ...,
data: builtins.bytes = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ...
def ClearField(
self,
field_name: typing.Literal[
"data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"
],
) -> None: ...
global___Payload = Payload
@@ -117,8 +123,31 @@ class RpcRequest(google.protobuf.message.Message):
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ...
def HasField(
self,
field_name: typing.Literal[
"_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"_source",
b"_source",
"metadata",
b"metadata",
"method",
b"method",
"payload",
b"payload",
"request_id",
b"request_id",
"source",
b"source",
"target",
b"target",
],
) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
global___RpcRequest = RpcRequest
@@ -162,7 +191,12 @@ class RpcResponse(google.protobuf.message.Message):
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ...
def ClearField(
self,
field_name: typing.Literal[
"error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"
],
) -> None: ...
global___RpcResponse = RpcResponse
@@ -208,8 +242,26 @@ class Event(google.protobuf.message.Message):
payload: global___Payload | None = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ...
def HasField(
self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"_source",
b"_source",
"metadata",
b"metadata",
"payload",
b"payload",
"source",
b"source",
"topic_source",
b"topic_source",
"topic_type",
b"topic_type",
],
) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ...
global___Event = Event
@@ -250,7 +302,12 @@ class RegisterAgentTypeResponse(google.protobuf.message.Message):
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
def ClearField(
self,
field_name: typing.Literal[
"_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"
],
) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___RegisterAgentTypeResponse = RegisterAgentTypeResponse
@@ -269,7 +326,9 @@ class TypeSubscription(google.protobuf.message.Message):
topic_type: builtins.str = ...,
agent_type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]) -> None: ...
def ClearField(
self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type", b"topic_type"]
) -> None: ...
global___TypeSubscription = TypeSubscription
@@ -287,7 +346,9 @@ class TypePrefixSubscription(google.protobuf.message.Message):
topic_type_prefix: builtins.str = ...,
agent_type: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type_prefix", b"topic_type_prefix"]) -> None: ...
def ClearField(
self, field_name: typing.Literal["agent_type", b"agent_type", "topic_type_prefix", b"topic_type_prefix"]
) -> None: ...
global___TypePrefixSubscription = TypePrefixSubscription
@@ -307,9 +368,31 @@ class Subscription(google.protobuf.message.Message):
typeSubscription: global___TypeSubscription | None = ...,
typePrefixSubscription: global___TypePrefixSubscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...
def HasField(
self,
field_name: typing.Literal[
"subscription",
b"subscription",
"typePrefixSubscription",
b"typePrefixSubscription",
"typeSubscription",
b"typeSubscription",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"subscription",
b"subscription",
"typePrefixSubscription",
b"typePrefixSubscription",
"typeSubscription",
b"typeSubscription",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing.Literal["subscription", b"subscription"]
) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...
global___Subscription = Subscription
@@ -329,7 +412,9 @@ class AddSubscriptionRequest(google.protobuf.message.Message):
subscription: global___Subscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["request_id", b"request_id", "subscription", b"subscription"]) -> None: ...
def ClearField(
self, field_name: typing.Literal["request_id", b"request_id", "subscription", b"subscription"]
) -> None: ...
global___AddSubscriptionRequest = AddSubscriptionRequest
@@ -351,7 +436,12 @@ class AddSubscriptionResponse(google.protobuf.message.Message):
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"]) -> None: ...
def ClearField(
self,
field_name: typing.Literal[
"_error", b"_error", "error", b"error", "request_id", b"request_id", "success", b"success"
],
) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___AddSubscriptionResponse = AddSubscriptionResponse
@@ -381,9 +471,41 @@ class AgentState(google.protobuf.message.Message):
text_data: builtins.str = ...,
proto_data: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "eTag", b"eTag", "proto_data", b"proto_data", "text_data", b"text_data"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
def HasField(
self,
field_name: typing.Literal[
"agent_id",
b"agent_id",
"binary_data",
b"binary_data",
"data",
b"data",
"proto_data",
b"proto_data",
"text_data",
b"text_data",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"agent_id",
b"agent_id",
"binary_data",
b"binary_data",
"data",
b"data",
"eTag",
b"eTag",
"proto_data",
b"proto_data",
"text_data",
b"text_data",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing.Literal["data", b"data"]
) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
global___AgentState = AgentState
@@ -405,8 +527,15 @@ class GetStateResponse(google.protobuf.message.Message):
success: builtins.bool = ...,
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "agent_state", b"agent_state", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "agent_state", b"agent_state", "error", b"error", "success", b"success"]) -> None: ...
def HasField(
self, field_name: typing.Literal["_error", b"_error", "agent_state", b"agent_state", "error", b"error"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"_error", b"_error", "agent_state", b"agent_state", "error", b"error", "success", b"success"
],
) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___GetStateResponse = GetStateResponse
@@ -426,7 +555,9 @@ class SaveStateResponse(google.protobuf.message.Message):
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "success", b"success"]) -> None: ...
def ClearField(
self, field_name: typing.Literal["_error", b"_error", "error", b"error", "success", b"success"]
) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___SaveStateResponse = SaveStateResponse
@@ -467,8 +598,61 @@ class Message(google.protobuf.message.Message):
addSubscriptionRequest: global___AddSubscriptionRequest | None = ...,
addSubscriptionResponse: global___AddSubscriptionResponse | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ...
def HasField(
self,
field_name: typing.Literal[
"addSubscriptionRequest",
b"addSubscriptionRequest",
"addSubscriptionResponse",
b"addSubscriptionResponse",
"cloudEvent",
b"cloudEvent",
"message",
b"message",
"registerAgentTypeRequest",
b"registerAgentTypeRequest",
"registerAgentTypeResponse",
b"registerAgentTypeResponse",
"request",
b"request",
"response",
b"response",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"addSubscriptionRequest",
b"addSubscriptionRequest",
"addSubscriptionResponse",
b"addSubscriptionResponse",
"cloudEvent",
b"cloudEvent",
"message",
b"message",
"registerAgentTypeRequest",
b"registerAgentTypeRequest",
"registerAgentTypeResponse",
b"registerAgentTypeResponse",
"request",
b"request",
"response",
b"response",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing.Literal["message", b"message"]
) -> (
typing.Literal[
"request",
"response",
"cloudEvent",
"registerAgentTypeRequest",
"registerAgentTypeResponse",
"addSubscriptionRequest",
"addSubscriptionResponse",
]
| None
): ...
global___Message = Message

View File

@@ -0,0 +1,167 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import agent_worker_pb2 as agent__worker__pb2
import grpc
class AgentRpcStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.OpenChannel = channel.stream_stream(
"/agents.AgentRpc/OpenChannel",
request_serializer=agent__worker__pb2.Message.SerializeToString,
response_deserializer=agent__worker__pb2.Message.FromString,
)
self.GetState = channel.unary_unary(
"/agents.AgentRpc/GetState",
request_serializer=agent__worker__pb2.AgentId.SerializeToString,
response_deserializer=agent__worker__pb2.GetStateResponse.FromString,
)
self.SaveState = channel.unary_unary(
"/agents.AgentRpc/SaveState",
request_serializer=agent__worker__pb2.AgentState.SerializeToString,
response_deserializer=agent__worker__pb2.SaveStateResponse.FromString,
)
class AgentRpcServicer(object):
    """Server-side base class for agents.AgentRpc.

    Every handler defaults to UNIMPLEMENTED; concrete servicers override
    the methods they support.
    """

    def _unimplemented(self, context):
        # Common default behavior for handlers a subclass has not overridden.
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Method not implemented!")
        raise NotImplementedError("Method not implemented!")

    def OpenChannel(self, request_iterator, context):
        """Bidirectional-streaming RPC; override in a concrete servicer."""
        self._unimplemented(context)

    def GetState(self, request, context):
        """Unary RPC; override in a concrete servicer."""
        self._unimplemented(context)

    def SaveState(self, request, context):
        """Unary RPC; override in a concrete servicer."""
        self._unimplemented(context)
def add_AgentRpcServicer_to_server(servicer, server):
    """Register *servicer*'s agents.AgentRpc handlers on *server*."""
    # (RPC name, handler factory, request type, response type)
    handler_specs = (
        (
            "OpenChannel",
            grpc.stream_stream_rpc_method_handler,
            agent__worker__pb2.Message,
            agent__worker__pb2.Message,
        ),
        (
            "GetState",
            grpc.unary_unary_rpc_method_handler,
            agent__worker__pb2.AgentId,
            agent__worker__pb2.GetStateResponse,
        ),
        (
            "SaveState",
            grpc.unary_unary_rpc_method_handler,
            agent__worker__pb2.AgentState,
            agent__worker__pb2.SaveStateResponse,
        ),
    )
    rpc_method_handlers = {
        rpc_name: make_handler(
            getattr(servicer, rpc_name),
            request_deserializer=request_type.FromString,
            response_serializer=response_type.SerializeToString,
        )
        for rpc_name, make_handler, request_type, response_type in handler_specs
    }
    generic_handler = grpc.method_handlers_generic_handler("agents.AgentRpc", rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class AgentRpc(object):
    """One-shot client helpers for the agents.AgentRpc service.

    Each static method issues the corresponding RPC through the
    grpc.experimental convenience API, without building an AgentRpcStub.
    """

    @staticmethod
    def OpenChannel(
        request_iterator,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the bidirectional-streaming OpenChannel RPC once."""
        # Trailing arguments are passed positionally to
        # grpc.experimental.stream_stream; note `insecure` precedes
        # `call_credentials` in that call — do not reorder.
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            "/agents.AgentRpc/OpenChannel",
            agent__worker__pb2.Message.SerializeToString,
            agent__worker__pb2.Message.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
        )

    @staticmethod
    def GetState(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the unary GetState RPC once (AgentId -> GetStateResponse)."""
        # Same positional ordering caveat as OpenChannel applies here.
        return grpc.experimental.unary_unary(
            request,
            target,
            "/agents.AgentRpc/GetState",
            agent__worker__pb2.AgentId.SerializeToString,
            agent__worker__pb2.GetStateResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
        )

    @staticmethod
    def SaveState(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        """Invoke the unary SaveState RPC once (AgentState -> SaveStateResponse)."""
        # Same positional ordering caveat as OpenChannel applies here.
        return grpc.experimental.unary_unary(
            request,
            target,
            "/agents.AgentRpc/SaveState",
            agent__worker__pb2.AgentState.SerializeToString,
            agent__worker__pb2.SaveStateResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
        )

View File

@@ -4,16 +4,16 @@ isort:skip_file
"""
import abc
import agent_worker_pb2
import collections.abc
import typing
import agent_worker_pb2
import grpc
import grpc.aio
import typing
_T = typing.TypeVar("_T")
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type: ignore[misc, type-arg]
...
@@ -56,20 +56,26 @@ class AgentRpcServicer(metaclass=abc.ABCMeta):
self,
request_iterator: _MaybeAsyncIterator[agent_worker_pb2.Message],
context: _ServicerContext,
) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ...
) -> typing.Union[
collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]
]: ...
@abc.abstractmethod
def GetState(
self,
request: agent_worker_pb2.AgentId,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.GetStateResponse, collections.abc.Awaitable[agent_worker_pb2.GetStateResponse]]: ...
) -> typing.Union[
agent_worker_pb2.GetStateResponse, collections.abc.Awaitable[agent_worker_pb2.GetStateResponse]
]: ...
@abc.abstractmethod
def SaveState(
self,
request: agent_worker_pb2.AgentState,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]]: ...
) -> typing.Union[
agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]
]: ...
def add_AgentRpcServicer_to_server(servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...
def add_AgentRpcServicer_to_server(
servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]
) -> None: ...

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: cloudevent.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()

from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2

# Register the compiled cloudevent.proto file descriptor with the default
# descriptor pool. The byte string is the serialized FileDescriptorProto.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\x10\x63loudevent.proto\x12\ncloudevent\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xa4\x05\n\nCloudEvent\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\x12\x14\n\x0cspec_version\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12:\n\nattributes\x18\x05 \x03(\x0b\x32&.cloudevent.CloudEvent.AttributesEntry\x12\x36\n\x08metadata\x18\x06 \x03(\x0b\x32$.cloudevent.CloudEvent.MetadataEntry\x12\x17\n\x0f\x64\x61tacontenttype\x18\x07 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x08 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\t \x01(\tH\x00\x12*\n\nproto_data\x18\n \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x1a\x62\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12>\n\x05value\x18\x02 \x01(\x0b\x32/.cloudevent.CloudEvent.CloudEventAttributeValue:\x02\x38\x01\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd3\x01\n\x18\x43loudEventAttributeValue\x12\x14\n\nce_boolean\x18\x01 \x01(\x08H\x00\x12\x14\n\nce_integer\x18\x02 \x01(\x05H\x00\x12\x13\n\tce_string\x18\x03 \x01(\tH\x00\x12\x12\n\x08\x63\x65_bytes\x18\x04 \x01(\x0cH\x00\x12\x10\n\x06\x63\x65_uri\x18\x05 \x01(\tH\x00\x12\x14\n\nce_uri_ref\x18\x06 \x01(\tH\x00\x12\x32\n\x0c\x63\x65_timestamp\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampH\x00\x42\x06\n\x04\x61ttrB\x06\n\x04\x64\x61taB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3'
)
_globals = globals()
# Materialize the generated message classes (CloudEvent and its nested
# entry/attribute-value types) into this module's namespace.
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "cloudevent_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    # Pure-Python descriptors only: attach serialized options and the byte
    # offsets of each message within the serialized file descriptor above.
    _globals["DESCRIPTOR"]._options = None
    _globals["DESCRIPTOR"]._serialized_options = b"\252\002\036Microsoft.AutoGen.Abstractions"
    _globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._options = None
    _globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._serialized_options = b"8\001"
    _globals["_CLOUDEVENT_METADATAENTRY"]._options = None
    _globals["_CLOUDEVENT_METADATAENTRY"]._serialized_options = b"8\001"
    _globals["_CLOUDEVENT"]._serialized_start = 93
    _globals["_CLOUDEVENT"]._serialized_end = 769
    _globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._serialized_start = 400
    _globals["_CLOUDEVENT_ATTRIBUTESENTRY"]._serialized_end = 498
    _globals["_CLOUDEVENT_METADATAENTRY"]._serialized_start = 500
    _globals["_CLOUDEVENT_METADATAENTRY"]._serialized_end = 547
    _globals["_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE"]._serialized_start = 550
    _globals["_CLOUDEVENT_CLOUDEVENTATTRIBUTEVALUE"]._serialized_end = 761
# @@protoc_insertion_point(module_scope)

View File

@@ -5,12 +5,13 @@ isort:skip_file
import builtins
import collections.abc
import typing
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import google.protobuf.timestamp_pb2
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@@ -87,9 +88,54 @@ class CloudEvent(google.protobuf.message.Message):
ce_uri_ref: builtins.str = ...,
ce_timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["attr", b"attr", "ce_boolean", b"ce_boolean", "ce_bytes", b"ce_bytes", "ce_integer", b"ce_integer", "ce_string", b"ce_string", "ce_timestamp", b"ce_timestamp", "ce_uri", b"ce_uri", "ce_uri_ref", b"ce_uri_ref"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["attr", b"attr"]) -> typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"] | None: ...
def HasField(
self,
field_name: typing.Literal[
"attr",
b"attr",
"ce_boolean",
b"ce_boolean",
"ce_bytes",
b"ce_bytes",
"ce_integer",
b"ce_integer",
"ce_string",
b"ce_string",
"ce_timestamp",
b"ce_timestamp",
"ce_uri",
b"ce_uri",
"ce_uri_ref",
b"ce_uri_ref",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"attr",
b"attr",
"ce_boolean",
b"ce_boolean",
"ce_bytes",
b"ce_bytes",
"ce_integer",
b"ce_integer",
"ce_string",
b"ce_string",
"ce_timestamp",
b"ce_timestamp",
"ce_uri",
b"ce_uri",
"ce_uri_ref",
b"ce_uri_ref",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing.Literal["attr", b"attr"]
) -> (
typing.Literal["ce_boolean", "ce_integer", "ce_string", "ce_bytes", "ce_uri", "ce_uri_ref", "ce_timestamp"]
| None
): ...
ID_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
@@ -115,7 +161,9 @@ class CloudEvent(google.protobuf.message.Message):
binary_data: builtins.bytes
text_data: builtins.str
@property
def attributes(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]:
def attributes(
self,
) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___CloudEvent.CloudEventAttributeValue]:
"""Optional & Extension Attributes"""
@property
@@ -136,8 +184,41 @@ class CloudEvent(google.protobuf.message.Message):
text_data: builtins.str = ...,
proto_data: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["attributes", b"attributes", "binary_data", b"binary_data", "data", b"data", "datacontenttype", b"datacontenttype", "id", b"id", "metadata", b"metadata", "proto_data", b"proto_data", "source", b"source", "spec_version", b"spec_version", "text_data", b"text_data", "type", b"type"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
def HasField(
self,
field_name: typing.Literal[
"binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing.Literal[
"attributes",
b"attributes",
"binary_data",
b"binary_data",
"data",
b"data",
"datacontenttype",
b"datacontenttype",
"id",
b"id",
"metadata",
b"metadata",
"proto_data",
b"proto_data",
"source",
b"source",
"spec_version",
b"spec_version",
"text_data",
b"text_data",
"type",
b"type",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing.Literal["data", b"data"]
) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ...
global___CloudEvent = CloudEvent

View File

@@ -0,0 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

View File

@@ -0,0 +1,17 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import abc
import collections.abc
import typing
import grpc
import grpc.aio
_T = typing.TypeVar("_T")

# Iterator that satisfies both the synchronous and asynchronous iterator
# protocols, so one servicer signature can accept either a grpc or a
# grpc.aio request stream.
class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...

# Combined context type covering both the sync and asyncio servicer APIs.
class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext):  # type: ignore[misc, type-arg]
    ...

View File

@@ -3,8 +3,15 @@ from dataclasses import dataclass
from typing import List
import pytest
from autogen_core import AgentId, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import (
AgentId,
DefaultTopicId,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
default_subscription,
message_handler,
)
from autogen_core.components.models import ChatCompletionClient, CreateResult, SystemMessage, UserMessage
from autogen_ext.models import ReplayChatCompletionClient

View File

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: serialization_test.proto
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()

# Register the compiled serialization_test.proto descriptor (defines
# agents.ProtoMessage and agents.NestingProtoMessage) with the default pool.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18serialization_test.proto\x12\x06\x61gents"\x1f\n\x0cProtoMessage\x12\x0f\n\x07message\x18\x01 \x01(\t"L\n\x13NestingProtoMessage\x12\x0f\n\x07message\x18\x01 \x01(\t\x12$\n\x06nested\x18\x02 \x01(\x0b\x32\x14.agents.ProtoMessageb\x06proto3')
_globals = globals()
# Materialize the generated message classes into this module's namespace.
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "serialization_test_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    # Pure-Python descriptors only: record each message's byte offsets
    # within the serialized file descriptor above.
    DESCRIPTOR._options = None
    _globals["_PROTOMESSAGE"]._serialized_start=36
    _globals["_PROTOMESSAGE"]._serialized_end=67
    _globals["_NESTINGPROTOMESSAGE"]._serialized_start=69
    _globals["_NESTINGPROTOMESSAGE"]._serialized_end=145
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,46 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import google.protobuf.descriptor
import google.protobuf.message
import typing
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class ProtoMessage(google.protobuf.message.Message):
    """Typed stub for the agents.ProtoMessage protobuf message."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    MESSAGE_FIELD_NUMBER: builtins.int
    # Scalar string field; its wire number is exposed via MESSAGE_FIELD_NUMBER.
    message: builtins.str
    def __init__(
        self,
        *,
        message: builtins.str = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["message", b"message"]) -> None: ...

# Module-level alias following protoc's generated naming convention.
global___ProtoMessage = ProtoMessage
@typing.final
class NestingProtoMessage(google.protobuf.message.Message):
    """Typed stub for agents.NestingProtoMessage, which embeds a ProtoMessage."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor
    MESSAGE_FIELD_NUMBER: builtins.int
    NESTED_FIELD_NUMBER: builtins.int
    # Scalar string field.
    message: builtins.str
    @property
    def nested(self) -> global___ProtoMessage: ...
    def __init__(
        self,
        *,
        message: builtins.str = ...,
        nested: global___ProtoMessage | None = ...,
    ) -> None: ...
    # HasField is only valid for the singular submessage field "nested".
    def HasField(self, field_name: typing.Literal["nested", b"nested"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["message", b"message", "nested", b"nested"]) -> None: ...

# Module-level alias following protoc's generated naming convention.
global___NestingProtoMessage = NestingProtoMessage

View File

@@ -19,9 +19,8 @@ from autogen_core import (
try_get_known_serializers_for_type,
type_subscription,
)
from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost
from protos.serialization_test_pb2 import ProtoMessage
from test_utils import (
from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime, GrpcWorkerAgentRuntimeHost
from autogen_test_utils import (
CascadingAgent,
CascadingMessageType,
ContentMessage,
@@ -30,15 +29,16 @@ from test_utils import (
MessageType,
NoopAgent,
)
from protos.serialization_test_pb2 import ProtoMessage
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_single_worker() -> None:
host_address = "localhost:50051"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = WorkerAgentRuntime(host_address=host_address)
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
@@ -57,12 +57,12 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
@pytest.mark.asyncio
async def test_agent_types_must_be_unique_multiple_workers() -> None:
host_address = "localhost:50052"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
worker2 = WorkerAgentRuntime(host_address=host_address)
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker2.start()
await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
@@ -82,10 +82,10 @@ async def test_agent_types_must_be_unique_multiple_workers() -> None:
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
host_address = "localhost:50053"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
@@ -93,7 +93,7 @@ async def test_register_receives_publish() -> None:
)
await worker1.add_subscription(TypeSubscription("default", "name1"))
worker2 = WorkerAgentRuntime(host_address=host_address)
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker2.start()
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker2.register_factory(
@@ -127,9 +127,9 @@ async def test_register_receives_publish() -> None:
@pytest.mark.asyncio
async def test_register_receives_publish_cascade_single_worker() -> None:
host_address = "localhost:50054"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
num_agents = 5
@@ -164,7 +164,7 @@ async def test_register_receives_publish_cascade_single_worker() -> None:
async def test_register_receives_publish_cascade_multiple_workers() -> None:
logging.basicConfig(level=logging.DEBUG)
host_address = "localhost:50055"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
# TODO: Increasing num_initial_messages or max_round to 2 causes the test to fail.
@@ -176,16 +176,16 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)
# Run multiple workers one for each agent.
workers: List[WorkerAgentRuntime] = []
workers: List[GrpcWorkerAgentRuntime] = []
# Register agents
for i in range(num_agents):
runtime = WorkerAgentRuntime(host_address=host_address)
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
workers.append(runtime)
# Publish messages
publisher = WorkerAgentRuntime(host_address=host_address)
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
publisher.start()
for _ in range(num_initial_messages):
@@ -207,11 +207,11 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
@pytest.mark.asyncio
async def test_default_subscription() -> None:
host_address = "localhost:50056"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = WorkerAgentRuntime(host_address=host_address)
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
publisher = WorkerAgentRuntime(host_address=host_address)
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
@@ -241,11 +241,11 @@ async def test_default_subscription() -> None:
@pytest.mark.asyncio
async def test_default_subscription_other_source() -> None:
host_address = "localhost:50057"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
publisher = WorkerAgentRuntime(host_address=host_address)
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
@@ -275,11 +275,11 @@ async def test_default_subscription_other_source() -> None:
@pytest.mark.asyncio
async def test_type_subscription() -> None:
host_address = "localhost:50058"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = WorkerAgentRuntime(host_address=host_address)
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
publisher = WorkerAgentRuntime(host_address=host_address)
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
@@ -312,9 +312,9 @@ async def test_type_subscription() -> None:
@pytest.mark.asyncio
async def test_duplicate_subscription() -> None:
host_address = "localhost:50059"
host = WorkerAgentRuntimeHost(address=host_address)
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1_2 = WorkerAgentRuntime(host_address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
host.start()
try:
worker1.start()
@@ -343,10 +343,10 @@ async def test_duplicate_subscription() -> None:
@pytest.mark.asyncio
async def test_disconnected_agent() -> None:
host_address = "localhost:50060"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1_2 = WorkerAgentRuntime(host_address=host_address)
worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
# TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
# to some private properties. This needs to be updated once they are available publicly
@@ -421,13 +421,13 @@ class ProtoReceivingAgent(RoutedAgent):
@pytest.mark.asyncio
async def test_proto_payloads() -> None:
host_address = "localhost:50057"
host = WorkerAgentRuntimeHost(address=host_address)
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
receiver_runtime = WorkerAgentRuntime(
receiver_runtime = GrpcWorkerAgentRuntime(
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
)
receiver_runtime.start()
publisher_runtime = WorkerAgentRuntime(
publisher_runtime = GrpcWorkerAgentRuntime(
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
)
publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
@@ -473,10 +473,10 @@ async def test_grpc_max_message_size() -> None:
("grpc.max_receive_message_length", new_max_size),
]
host_address = "localhost:50061"
host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
worker2 = WorkerAgentRuntime(host_address=host_address)
worker3 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
host = GrpcWorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
worker1 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker3 = GrpcWorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
try:
host.start()

View File

@@ -5,8 +5,7 @@ import asyncio
import logging
import os
from autogen_core import AgentId, AgentProxy
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import AgentId, AgentProxy, SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_core.components.code_executor import CodeBlock
from autogen_ext.code_executors import DockerCommandLineCodeExecutor

View File

@@ -7,8 +7,7 @@ round-robin orchestrator agent. The code snippets are executed inside a docker c
import asyncio
import logging
from autogen_core import AgentId, AgentProxy
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import AgentId, AgentProxy, SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_core.components.code_executor import CodeBlock
from autogen_ext.code_executors import DockerCommandLineCodeExecutor

View File

@@ -5,8 +5,7 @@ to write input or perform actions, orchestrated by an round-robin orchestrator a
import asyncio
import logging
from autogen_core import AgentId, AgentProxy
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core import AgentId, AgentProxy, SingleThreadedAgentRuntime
from autogen_core.application.logging import EVENT_LOGGER_NAME
from autogen_magentic_one.agents.file_surfer import FileSurfer
from autogen_magentic_one.agents.orchestrator import RoundRobinOrchestrator

Some files were not shown because too many files have changed in this diff Show More