From a52d3bab536b580b6c80d3315cedc46bd87fef74 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 23 Jul 2024 11:49:38 -0700 Subject: [PATCH] Agent factory can be async (#247) --- .../GAIA/Templates/TeamOne/scenario.py | 30 ++++--- .../HumanEval/Templates/TeamOne/scenario.py | 6 +- .../HumanEval/Templates/TwoAgents/scenario.py | 17 ++-- python/pyproject.toml | 2 +- .../common/patterns/_group_chat_utils.py | 20 +++-- .../common/patterns/_orchestrator_chat.py | 5 +- python/samples/core/inner_outer_direct.py | 4 +- python/samples/core/one_agent_direct.py | 2 +- python/samples/core/two_agents_pub_sub.py | 4 +- python/samples/demos/assistant.py | 10 +-- python/samples/demos/chat_room.py | 18 ++-- python/samples/demos/chess_game.py | 10 +-- python/samples/demos/illustrator_critics.py | 20 ++--- python/samples/demos/software_consultancy.py | 16 ++-- python/samples/marketing-agents/app.py | 8 +- python/samples/marketing-agents/test_usage.py | 2 +- python/samples/patterns/coder_executor.py | 6 +- python/samples/patterns/coder_reviewer.py | 4 +- python/samples/patterns/group_chat.py | 8 +- python/samples/patterns/mixture_of_agents.py | 8 +- python/samples/patterns/multi_agent_debate.py | 10 +-- .../tool-use/coding_one_agent_direct.py | 2 +- .../tool-use/coding_two_agent_pub_sub.py | 4 +- .../custom_function_tool_one_agent_direct.py | 2 +- .../_single_threaded_agent_runtime.py | 90 ++++++++++--------- .../src/agnext/components/_closure_agent.py | 4 +- python/src/agnext/core/__init__.py | 3 +- python/src/agnext/core/_agent_proxy.py | 12 +-- python/src/agnext/core/_agent_runtime.py | 74 ++++++++------- python/src/agnext/core/_base_agent.py | 4 +- python/src/agnext/worker/worker_runtime.py | 72 ++++++++------- python/teams/team-one/examples/example.py | 8 +- .../teams/team-one/examples/example_coder.py | 8 +- .../team-one/examples/example_file_surfer.py | 6 +- .../team-one/examples/example_reflexagents.py | 8 +- .../team-one/examples/example_userproxy.py | 6 +- .../team-one/examples/example_websurfer.py | 6 +- python/teams/team-one/pyproject.toml | 28 ++++-- .../src/team_one/agents/base_orchestrator.py | 2 +- .../multimodal_web_surfer/set_of_mark.py | 2 +- .../src/team_one/agents/orchestrator.py | 18 ++-- python/tests/test_base_agent.py | 4 +- python/tests/test_cancellation.py | 20 ++--- python/tests/test_closure_agent.py | 2 +- python/tests/test_intervention.py | 18 ++-- python/tests/test_runtime.py | 16 ++-- python/tests/test_state.py | 22 ++--- 47 files changed, 352 insertions(+), 299 deletions(-) diff --git a/python/benchmarks/GAIA/Templates/TeamOne/scenario.py b/python/benchmarks/GAIA/Templates/TeamOne/scenario.py index 507323d87..dee4dfd53 100644 --- a/python/benchmarks/GAIA/Templates/TeamOne/scenario.py +++ b/python/benchmarks/GAIA/Templates/TeamOne/scenario.py @@ -1,9 +1,8 @@ import asyncio import logging -import json import os -from typing import Any, Dict, List, Tuple, Union +from typing import List from agnext.application import SingleThreadedAgentRuntime from agnext.application.logging import EVENT_LOGGER_NAME from agnext.components.models import ( @@ -18,13 +17,17 @@ from agnext.application.logging import EVENT_LOGGER_NAME from team_one.markdown_browser import MarkdownConverter, UnsupportedFormatException from team_one.agents.coder import Coder, Executor from team_one.agents.orchestrator import LedgerOrchestrator -from team_one.messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage +from team_one.messages import BroadcastMessage from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer from team_one.agents.file_surfer import FileSurfer -from team_one.utils import LogHandler, message_content_to_str, create_completion_client_from_env +from team_one.utils import LogHandler, message_content_to_str + +import re + +from agnext.components.models import AssistantMessage -async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]): +async def response_preparer(task: str, source: str, client: ChatCompletionClient, transcript: List[LLMMessage]) -> str: messages: List[LLMMessage] = [ UserMessage( content=f"Earlier you were asked the following:\n\n{task}\n\nYour team then worked diligently to address that request. Here is a transcript of that conversation:", @@ -37,7 +40,8 @@ async def response_preparer(task: str, source: str, client: ChatCompletionClient messages.append( UserMessage( content = message_content_to_str(message.content), - source=message.source, + # TODO fix this -> remove type ignore + source=message.source, # type: ignore ) ) @@ -68,7 +72,7 @@ If you are asked for a comma separated list, apply the above rules depending on # No answer if "unable to determine" in response.content.lower(): messages.append( AssistantMessage(content=response.content, source="self" ) ) - messages.append( + messages.append( UserMessage( content= f""" I understand that a definitive answer could not be determined. Please make a well-informed EDUCATED GUESS based on the conversation. @@ -115,29 +119,29 @@ async def main() -> None: ) # Register agents. - coder = runtime.register_and_get_proxy( + coder = await runtime.register_and_get_proxy( "Coder", lambda: Coder(model_client=client), ) - executor = runtime.register_and_get_proxy( + executor = await runtime.register_and_get_proxy( "Executor", lambda: Executor( "A agent for executing code", executor=LocalCommandLineCodeExecutor() ), ) - file_surfer = runtime.register_and_get_proxy( + file_surfer = await runtime.register_and_get_proxy( "file_surfer", lambda: FileSurfer(model_client=client), ) - web_surfer = runtime.register_and_get_proxy( + web_surfer = await runtime.register_and_get_proxy( "WebSurfer", lambda: MultimodalWebSurfer(), # Configuration is set later by init() ) - orchestrator = runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator( + orchestrator = await runtime.register_and_get_proxy("orchestrator", lambda: LedgerOrchestrator( agents=[coder, executor, file_surfer, web_surfer], model_client=client, )) @@ -185,7 +189,7 @@ async def main() -> None: actual_orchestrator = runtime._get_agent(orchestrator.id) # type: ignore assert isinstance(actual_orchestrator, LedgerOrchestrator) transcript: List[LLMMessage] = actual_orchestrator._chat_history # type: ignore - print(await response_preparer(task=task, source=orchestrator.metadata["name"], client=client, transcript=transcript)) + print(await response_preparer(task=task, source=(await orchestrator.metadata)["name"], client=client, transcript=transcript)) diff --git a/python/benchmarks/HumanEval/Templates/TeamOne/scenario.py b/python/benchmarks/HumanEval/Templates/TeamOne/scenario.py index d37412d19..8f9924c53 100644 --- a/python/benchmarks/HumanEval/Templates/TeamOne/scenario.py +++ b/python/benchmarks/HumanEval/Templates/TeamOne/scenario.py @@ -34,18 +34,18 @@ async def main() -> None: ) # Register agents. - coder = runtime.register_and_get_proxy( + coder = await runtime.register_and_get_proxy( "Coder", lambda: Coder(model_client=client), ) - executor = runtime.register_and_get_proxy( + executor = await runtime.register_and_get_proxy( "Executor", lambda: Executor( "A agent for executing code", executor=LocalCommandLineCodeExecutor() ), ) - runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor])) + await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor])) prompt = "" with open("prompt.txt", "rt") as fh: diff --git a/python/benchmarks/HumanEval/Templates/TwoAgents/scenario.py b/python/benchmarks/HumanEval/Templates/TwoAgents/scenario.py index cbb1e2386..835d6c18d 100644 --- a/python/benchmarks/HumanEval/Templates/TwoAgents/scenario.py +++ b/python/benchmarks/HumanEval/Templates/TwoAgents/scenario.py @@ -1,12 +1,11 @@ import asyncio -import json import re import uuid from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, List, Union from agnext.application import SingleThreadedAgentRuntime -from agnext.components import FunctionCall, TypeRoutedAgent, message_handler +from agnext.components import TypeRoutedAgent, message_handler from agnext.components.code_executor import ( CodeBlock, CodeExecutor, @@ -16,16 +15,12 @@ from agnext.components.models import ( AssistantMessage, AzureOpenAIChatCompletionClient, ChatCompletionClient, - FunctionExecutionResult, - FunctionExecutionResultMessage, LLMMessage, ModelCapabilities, - OpenAIChatCompletionClient, SystemMessage, UserMessage, ) -from agnext.components.tools import CodeExecutionResult, PythonCodeExecutionTool -from agnext.core import AgentId, CancellationToken +from agnext.core import CancellationToken # from azure.identity import DefaultAzureCredential, get_bearer_token_provider @@ -66,7 +61,7 @@ if __name__ == "__main__": main() ``` -The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions. +The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions. Check the execution result returned by the user. If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes -- code blocks must stand alone and be ready to execute without modification. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, and think of a different approach to try. @@ -222,11 +217,11 @@ async def main() -> None: ) # Register agents. - coder = runtime.register_and_get( + coder = await runtime.register_and_get( "Coder", lambda: Coder(model_client=client), ) - runtime.register( + await runtime.register( "Executor", lambda: Executor( "A agent for executing code", executor=LocalCommandLineCodeExecutor() diff --git a/python/pyproject.toml b/python/pyproject.toml index 63a7880b7..f2e630aa7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "mypy==1.10.0", "ruff==0.4.8", "tiktoken", - "types-Pillow", + "types-pillow", "polars", "chess", "tavily-python", diff --git a/python/samples/common/patterns/_group_chat_utils.py b/python/samples/common/patterns/_group_chat_utils.py index ed46f3751..4b0d7ebda 100644 --- a/python/samples/common/patterns/_group_chat_utils.py +++ b/python/samples/common/patterns/_group_chat_utils.py @@ -22,10 +22,12 @@ async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClie history = "\n".join(history_messages) # Construct agent roles. - roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents]) + roles = "\n".join( + [f"{(await agent.metadata)['name']}: {(await agent.metadata)['description']}".strip() for agent in agents] + ) # Construct agent list. - participants = str([agent.metadata["name"] for agent in agents]) + participants = str([(await agent.metadata)["name"] for agent in agents]) # Select the next speaker. select_speaker_prompt = f"""You are in a role play game. The following roles are available: @@ -39,16 +41,22 @@ Read the above conversation. Then select the next role from {participants} to pl select_speaker_messages = [SystemMessage(select_speaker_prompt)] response = await client.create(messages=select_speaker_messages) assert isinstance(response.content, str) - mentions = mentioned_agents(response.content, agents) + mentions = await mentioned_agents(response.content, agents) if len(mentions) != 1: raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}") agent_name = list(mentions.keys())[0] - agent_index = next((i for i, agent in enumerate(agents) if agent.metadata["name"] == agent_name), None) + # Get the index of the selected agent by name + agent_index = 0 + for i, agent in enumerate(agents): + if (await agent.metadata)["name"] == agent_name: + agent_index = i + break + assert agent_index is not None return agent_index -def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]: +async def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]: """Counts the number of times each agent is mentioned in the provided message content. Agent names will match under any of the following conditions (all case-sensitive): - Exact name match @@ -66,7 +74,7 @@ def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str for agent in agents: # Finds agent mentions, taking word boundaries into account, # accommodates escaping underscores and underscores as spaces - name = agent.metadata["name"] + name = (await agent.metadata)["name"] regex = ( r"(?<=\W)(" + re.escape(name) diff --git a/python/samples/common/patterns/_orchestrator_chat.py b/python/samples/common/patterns/_orchestrator_chat.py index d59fd4819..862db4f15 100644 --- a/python/samples/common/patterns/_orchestrator_chat.py +++ b/python/samples/common/patterns/_orchestrator_chat.py @@ -170,7 +170,10 @@ Some additional points to consider: # A reusable description of the team. team = "\n".join( - [agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists] + [ + agent.name + ": " + (await self.runtime.agent_metadata(agent))["description"] + for agent in self._specialists + ] ) names = ", ".join([agent.name for agent in self._specialists]) diff --git a/python/samples/core/inner_outer_direct.py b/python/samples/core/inner_outer_direct.py index a1c12a412..d90ee0eef 100644 --- a/python/samples/core/inner_outer_direct.py +++ b/python/samples/core/inner_outer_direct.py @@ -45,8 +45,8 @@ class Outer(TypeRoutedAgent): async def main() -> None: runtime = SingleThreadedAgentRuntime() - inner = runtime.register_and_get("inner", Inner) - outer = runtime.register_and_get("outer", lambda: Outer(inner)) + inner = await runtime.register_and_get("inner", Inner) + outer = await runtime.register_and_get("outer", lambda: Outer(inner)) run_context = runtime.start() diff --git a/python/samples/core/one_agent_direct.py b/python/samples/core/one_agent_direct.py index ab87ffffd..2aa37a158 100644 --- a/python/samples/core/one_agent_direct.py +++ b/python/samples/core/one_agent_direct.py @@ -45,7 +45,7 @@ class ChatCompletionAgent(TypeRoutedAgent): async def main() -> None: runtime = SingleThreadedAgentRuntime() - agent = runtime.register_and_get( + agent = await runtime.register_and_get( "chat_agent", lambda: ChatCompletionAgent("Chat agent", get_chat_completion_client_from_envs(model="gpt-3.5-turbo")), ) diff --git a/python/samples/core/two_agents_pub_sub.py b/python/samples/core/two_agents_pub_sub.py index 96341524a..ffaba2c17 100644 --- a/python/samples/core/two_agents_pub_sub.py +++ b/python/samples/core/two_agents_pub_sub.py @@ -77,7 +77,7 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register the agents. - jack = runtime.register_and_get( + jack = await runtime.register_and_get( "Jack", lambda: ChatCompletionAgent( description="Jack a comedian", @@ -88,7 +88,7 @@ async def main() -> None: termination_word="TERMINATE", ), ) - runtime.register_and_get( + await runtime.register_and_get( "Cathy", lambda: ChatCompletionAgent( description="Cathy a poet", diff --git a/python/samples/demos/assistant.py b/python/samples/demos/assistant.py index 29b9b8288..781f7b5fc 100644 --- a/python/samples/demos/assistant.py +++ b/python/samples/demos/assistant.py @@ -166,7 +166,7 @@ class EventHandler(AsyncAssistantEventHandler): print("\n".join(citations)) -def assistant_chat(runtime: AgentRuntime) -> AgentId: +async def assistant_chat(runtime: AgentRuntime) -> AgentId: oai_assistant = openai.beta.assistants.create( model="gpt-4-turbo", description="An AI assistant that helps with everyday tasks.", @@ -177,7 +177,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId: thread = openai.beta.threads.create( tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, ) - assistant = runtime.register_and_get( + assistant = await runtime.register_and_get( "Assistant", lambda: OpenAIAssistantAgent( description="An AI assistant that helps with everyday tasks.", @@ -188,7 +188,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId: ), ) - user = runtime.register_and_get( + user = await runtime.register_and_get( "User", lambda: UserProxyAgent( client=openai.AsyncClient(), @@ -198,7 +198,7 @@ def assistant_chat(runtime: AgentRuntime) -> AgentId: ), ) # Create a group chat manager to facilitate a turn-based conversation. - runtime.register( + await runtime.register( "GroupChatManager", lambda: GroupChatManager( description="A group chat manager.", @@ -225,7 +225,7 @@ This will upload data.csv to the assistant for use with the code interpreter too Type "exit" to exit the chat. """ runtime = SingleThreadedAgentRuntime() - user = assistant_chat(runtime) + user = await assistant_chat(runtime) _run_context = runtime.start() print(usage) # Request the user to start the conversation. diff --git a/python/samples/demos/chat_room.py b/python/samples/demos/chat_room.py index eef00eabd..b3f8272a7 100644 --- a/python/samples/demos/chat_room.py +++ b/python/samples/demos/chat_room.py @@ -87,15 +87,15 @@ class ChatRoomUserAgent(TextualUserAgent): # Define a chat room with participants -- the runtime is the chat room. -def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: - runtime.register( +async def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: + await runtime.register( "User", lambda: ChatRoomUserAgent( description="The user in the chat room.", app=app, ), ) - alice = runtime.register_and_get_proxy( + alice = await runtime.register_and_get_proxy( "Alice", lambda rt, id: ChatRoomAgent( name=id.name, @@ -105,7 +105,7 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) - bob = runtime.register_and_get_proxy( + bob = await runtime.register_and_get_proxy( "Bob", lambda rt, id: ChatRoomAgent( name=id.name, @@ -115,7 +115,7 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) - charlie = runtime.register_and_get_proxy( + charlie = await runtime.register_and_get_proxy( "Charlie", lambda rt, id: ChatRoomAgent( name=id.name, @@ -126,9 +126,9 @@ def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: ), ) app.welcoming_notice = f"""Welcome to the chat room demo with the following participants: -1. 👧 {alice.id.name}: {alice.metadata['description']} -2. 👱🏼‍♂️ {bob.id.name}: {bob.metadata['description']} -3. 👨🏾‍🦳 {charlie.id.name}: {charlie.metadata['description']} +1. 👧 {alice.id.name}: {(await alice.metadata)['description']} +2. 👱🏼‍♂️ {bob.id.name}: {(await bob.metadata)['description']} +3. 👨🏾‍🦳 {charlie.id.name}: {(await charlie.metadata)['description']} Each participant decides on its own whether to respond to the latest message. @@ -139,7 +139,7 @@ You can greet the chat room by typing your first message below. async def main() -> None: runtime = SingleThreadedAgentRuntime() app = TextualChatApp(runtime, user_name="You") - chat_room(runtime, app) + await chat_room(runtime, app) _run_context = runtime.start() await app.run_async() diff --git a/python/samples/demos/chess_game.py b/python/samples/demos/chess_game.py index b08bf06b0..edbab5068 100644 --- a/python/samples/demos/chess_game.py +++ b/python/samples/demos/chess_game.py @@ -88,7 +88,7 @@ def make_move( return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[newMove.from_square]} to {SQUARE_NAMES[newMove.to_square]}." -def chess_game(runtime: AgentRuntime) -> None: # type: ignore +async def chess_game(runtime: AgentRuntime) -> None: # type: ignore """Create agents for a chess game and return the group chat.""" # Create the board. @@ -156,7 +156,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore ), ] - black = runtime.register_and_get( + black = await runtime.register_and_get( "PlayerBlack", lambda: ChatCompletionAgent( description="Player playing black.", @@ -173,7 +173,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore tools=black_tools, ), ) - white = runtime.register_and_get( + white = await runtime.register_and_get( "PlayerWhite", lambda: ChatCompletionAgent( description="Player playing white.", @@ -192,7 +192,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore ) # Create a group chat manager for the chess game to orchestrate a turn-based # conversation between the two agents. - runtime.register( + await runtime.register( "ChessGame", lambda: GroupChatManager( description="A chess game between two agents.", @@ -204,7 +204,7 @@ def chess_game(runtime: AgentRuntime) -> None: # type: ignore async def main() -> None: runtime = SingleThreadedAgentRuntime() - chess_game(runtime) + await chess_game(runtime) # Publish an initial message to trigger the group chat manager to start orchestration. await runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default") while True: diff --git a/python/samples/demos/illustrator_critics.py b/python/samples/demos/illustrator_critics.py index 75183f411..a0950bbb3 100644 --- a/python/samples/demos/illustrator_critics.py +++ b/python/samples/demos/illustrator_critics.py @@ -19,15 +19,15 @@ from common.utils import get_chat_completion_client_from_envs from utils import TextualChatApp, TextualUserAgent -def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None: - runtime.register( +async def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None: + await runtime.register( "User", lambda: TextualUserAgent( description="A user looking for illustration.", app=app, ), ) - descriptor = runtime.register_and_get_proxy( + descriptor = await runtime.register_and_get_proxy( "Descriptor", lambda: ChatCompletionAgent( description="An AI agent that provides a description of the image.", @@ -46,7 +46,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo", max_tokens=500), ), ) - illustrator = runtime.register_and_get_proxy( + illustrator = await runtime.register_and_get_proxy( "Illustrator", lambda: ImageGenerationAgent( description="An AI agent that generates images.", @@ -55,7 +55,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None: memory=BufferedChatMemory(buffer_size=1), ), ) - critic = runtime.register_and_get_proxy( + critic = await runtime.register_and_get_proxy( "Critic", lambda: ChatCompletionAgent( description="An AI agent that provides feedback on images given user's requirements.", @@ -74,7 +74,7 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None: model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"), ), ) - runtime.register( + await runtime.register( "GroupChatManager", lambda: GroupChatManager( description="A chat manager that handles group chat.", @@ -86,9 +86,9 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> None: app.welcoming_notice = f"""You are now in a group chat with the following agents: -1. 🤖 {descriptor.metadata['name']}: {descriptor.metadata.get('description')} -2. 🤖 {illustrator.metadata['name']}: {illustrator.metadata.get('description')} -3. 🤖 {critic.metadata['name']}: {critic.metadata.get('description')} +1. 🤖 {(await descriptor.metadata)['name']}: {(await descriptor.metadata).get('description')} +2. 🤖 {(await illustrator.metadata)['name']}: {(await illustrator.metadata).get('description')} +3. 🤖 {(await critic.metadata)['name']}: {(await critic.metadata).get('description')} Provide a prompt for the illustrator to generate an image. """ @@ -97,7 +97,7 @@ Provide a prompt for the illustrator to generate an image. async def main() -> None: runtime = SingleThreadedAgentRuntime() app = TextualChatApp(runtime, user_name="You") - illustrator_critics(runtime, app) + await illustrator_critics(runtime, app) _run_context = runtime.start() await app.run_async() diff --git a/python/samples/demos/software_consultancy.py b/python/samples/demos/software_consultancy.py index 28de35a41..b239fbca8 100644 --- a/python/samples/demos/software_consultancy.py +++ b/python/samples/demos/software_consultancy.py @@ -105,15 +105,15 @@ async def create_image( return f"Image created and saved to {filename}." -def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore - user_agent = runtime.register_and_get( +async def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore + user_agent = await runtime.register_and_get( "Customer", lambda: TextualUserAgent( description="A customer looking for help.", app=app, ), ) - developer = runtime.register_and_get( + developer = await runtime.register_and_get( "Developer", lambda: ChatCompletionAgent( description="A Python software developer.", @@ -153,7 +153,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # ), ) - product_manager = runtime.register_and_get( + product_manager = await runtime.register_and_get( "ProductManager", lambda: ChatCompletionAgent( description="A product manager. " @@ -182,7 +182,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # tool_approver=user_agent, ), ) - ux_designer = runtime.register_and_get( + ux_designer = await runtime.register_and_get( "UserExperienceDesigner", lambda: ChatCompletionAgent( description="A user experience designer for creating user interfaces.", @@ -215,7 +215,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # ), ) - illustrator = runtime.register_and_get( + illustrator = await runtime.register_and_get( "Illustrator", lambda: ChatCompletionAgent( description="An illustrator for creating images.", @@ -240,7 +240,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # tool_approver=user_agent, ), ) - runtime.register( + await runtime.register( "GroupChatManager", lambda: GroupChatManager( description="A group chat manager.", @@ -279,7 +279,7 @@ def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # async def main() -> None: runtime = SingleThreadedAgentRuntime() app = TextualChatApp(runtime, user_name="You") - software_consultancy(runtime, app) + await software_consultancy(runtime, app) # Start the runtime. _run_context = runtime.start() # Start the app. diff --git a/python/samples/marketing-agents/app.py b/python/samples/marketing-agents/app.py index 5a810aeae..5794ede92 100644 --- a/python/samples/marketing-agents/app.py +++ b/python/samples/marketing-agents/app.py @@ -27,8 +27,8 @@ async def build_app(runtime: AgentRuntime) -> None: api_version="2024-02-01", ) - runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client)) - runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client)) + await runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client)) + await runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client)) - runtime.get("GraphicDesigner") - runtime.get("Auditor") + await runtime.get("GraphicDesigner") + await runtime.get("Auditor") diff --git a/python/samples/marketing-agents/test_usage.py b/python/samples/marketing-agents/test_usage.py index 3ff179918..3d22ba456 100644 --- a/python/samples/marketing-agents/test_usage.py +++ b/python/samples/marketing-agents/test_usage.py @@ -30,7 +30,7 @@ class Printer(TypeRoutedAgent): async def main() -> None: runtime = SingleThreadedAgentRuntime() await build_app(runtime) - runtime.register("Printer", lambda: Printer()) + await runtime.register("Printer", lambda: Printer()) ctx = runtime.start() diff --git a/python/samples/patterns/coder_executor.py b/python/samples/patterns/coder_executor.py index a735313b5..5b77c8500 100644 --- a/python/samples/patterns/coder_executor.py +++ b/python/samples/patterns/coder_executor.py @@ -180,8 +180,10 @@ async def main(task: str, temp_dir: str) -> None: runtime = SingleThreadedAgentRuntime() # Register the agents. - runtime.register("coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo"))) - runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir))) + await runtime.register( + "coder", lambda: Coder(model_client=get_chat_completion_client_from_envs(model="gpt-4-turbo")) + ) + await runtime.register("executor", lambda: Executor(executor=LocalCommandLineCodeExecutor(work_dir=temp_dir))) run_context = runtime.start() # Publish the task message. diff --git a/python/samples/patterns/coder_reviewer.py b/python/samples/patterns/coder_reviewer.py index 9027e3e28..acddc5f38 100644 --- a/python/samples/patterns/coder_reviewer.py +++ b/python/samples/patterns/coder_reviewer.py @@ -251,14 +251,14 @@ Code: async def main() -> None: runtime = SingleThreadedAgentRuntime() - runtime.register( + await runtime.register( "ReviewerAgent", lambda: ReviewerAgent( description="Code Reviewer", model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), ), ) - runtime.register( + await runtime.register( "CoderAgent", lambda: CoderAgent( description="Coder", diff --git a/python/samples/patterns/group_chat.py b/python/samples/patterns/group_chat.py index f8cf43237..710a7c6ed 100644 --- a/python/samples/patterns/group_chat.py +++ b/python/samples/patterns/group_chat.py @@ -113,7 +113,7 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register the participants. - agent1 = runtime.register_and_get( + agent1 = await runtime.register_and_get( "DataScientist", lambda: GroupChatParticipant( description="A data scientist", @@ -121,7 +121,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), ), ) - agent2 = runtime.register_and_get( + agent2 = await runtime.register_and_get( "Engineer", lambda: GroupChatParticipant( description="An engineer", @@ -129,7 +129,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), ), ) - agent3 = runtime.register_and_get( + agent3 = await runtime.register_and_get( "Artist", lambda: GroupChatParticipant( description="An artist", @@ -139,7 +139,7 @@ async def main() -> None: ) # Register the group chat manager. - runtime.register( + await runtime.register( "GroupChatManager", lambda: RoundRobinGroupChatManager( description="A group chat manager", diff --git a/python/samples/patterns/mixture_of_agents.py b/python/samples/patterns/mixture_of_agents.py index 5b34886d7..4ef5e3e4a 100644 --- a/python/samples/patterns/mixture_of_agents.py +++ b/python/samples/patterns/mixture_of_agents.py @@ -112,7 +112,7 @@ class AggregatorAgent(TypeRoutedAgent): async def main() -> None: runtime = SingleThreadedAgentRuntime() # TODO: use different models for each agent. - runtime.register( + await runtime.register( "ReferenceAgent1", lambda: ReferenceAgent( description="Reference Agent 1", @@ -120,7 +120,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=0.1), ), ) - runtime.register( + await runtime.register( "ReferenceAgent2", lambda: ReferenceAgent( description="Reference Agent 2", @@ -128,7 +128,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=0.5), ), ) - runtime.register( + await runtime.register( "ReferenceAgent3", lambda: ReferenceAgent( description="Reference Agent 3", @@ -136,7 +136,7 @@ async def main() -> None: model_client=get_chat_completion_client_from_envs(model="gpt-3.5-turbo", temperature=1.0), ), ) - runtime.register( + await runtime.register( "AggregatorAgent", lambda: AggregatorAgent( description="Aggregator Agent", diff --git a/python/samples/patterns/multi_agent_debate.py b/python/samples/patterns/multi_agent_debate.py index 8a2551601..57d3baed7 100644 --- a/python/samples/patterns/multi_agent_debate.py +++ b/python/samples/patterns/multi_agent_debate.py @@ -211,7 +211,7 @@ async def main(question: str) -> None: # Register the solver agents. # Create a sparse connection: each solver agent has two neighbors. # NOTE: to create a dense connection, each solver agent should be connected to all other solver agents. - runtime.register( + await runtime.register( "MathSolver1", lambda: MathSolver( get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), @@ -219,7 +219,7 @@ async def main(question: str) -> None: max_round=3, ), ) - runtime.register( + await runtime.register( "MathSolver2", lambda: MathSolver( get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), @@ -227,7 +227,7 @@ async def main(question: str) -> None: max_round=3, ), ) - runtime.register( + await runtime.register( "MathSolver3", lambda: MathSolver( get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), @@ -235,7 +235,7 @@ async def main(question: str) -> None: max_round=3, ), ) - runtime.register( + await runtime.register( "MathSolver4", lambda: MathSolver( get_chat_completion_client_from_envs(model="gpt-3.5-turbo"), @@ -244,7 +244,7 @@ async def main(question: str) -> None: ), ) # Register the aggregator agent. - runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4)) + await runtime.register("MathAggregator", lambda: MathAggregator(num_solvers=4)) run_context = runtime.start() diff --git a/python/samples/tool-use/coding_one_agent_direct.py b/python/samples/tool-use/coding_one_agent_direct.py index 1c0d33996..db48d7e26 100644 --- a/python/samples/tool-use/coding_one_agent_direct.py +++ b/python/samples/tool-use/coding_one_agent_direct.py @@ -130,7 +130,7 @@ async def main() -> None: ) ] # Register agents. - tool_agent = runtime.register_and_get( + tool_agent = await runtime.register_and_get( "tool_enabled_agent", lambda: ToolEnabledAgent( description="Tool Use Agent", diff --git a/python/samples/tool-use/coding_two_agent_pub_sub.py b/python/samples/tool-use/coding_two_agent_pub_sub.py index 318c30bdb..c9ba45cb0 100644 --- a/python/samples/tool-use/coding_two_agent_pub_sub.py +++ b/python/samples/tool-use/coding_two_agent_pub_sub.py @@ -191,8 +191,8 @@ async def main() -> None: ) ] # Register agents. - runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools)) - runtime.register( + await runtime.register("tool_executor", lambda: ToolExecutorAgent("Tool Executor", tools)) + await runtime.register( "tool_use_agent", lambda: ToolUseAgent( description="Tool Use Agent", diff --git a/python/samples/tool-use/custom_function_tool_one_agent_direct.py b/python/samples/tool-use/custom_function_tool_one_agent_direct.py index ef93edca1..d6b9acb36 100644 --- a/python/samples/tool-use/custom_function_tool_one_agent_direct.py +++ b/python/samples/tool-use/custom_function_tool_one_agent_direct.py @@ -32,7 +32,7 @@ async def main() -> None: # Create the runtime. runtime = SingleThreadedAgentRuntime() # Register agents. - tool_agent = runtime.register_and_get( + tool_agent = await runtime.register_and_get( "tool_enabled_agent", lambda: ToolEnabledAgent( description="Tool Use Agent", diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index f4511ee3a..585a2b1ce 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -123,7 +123,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] # (namespace, type) -> List[AgentId] self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set) - self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {} + self._agent_factories: Dict[ + str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] + ] = {} self._instantiated_agents: Dict[AgentId, Agent] = {} self._intervention_handler = intervention_handler self._known_namespaces: set[str] = set() @@ -173,7 +175,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): if sender is not None and sender.namespace != recipient.namespace: raise ValueError("Sender and recipient must be in the same namespace to communicate.") - self._process_seen_namespace(recipient.namespace) + await self._process_seen_namespace(recipient.namespace) content = message.__dict__ if hasattr(message, "__dict__") else message logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {content}") @@ -227,7 +229,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): assert explicit_namespace is not None or sender_namespace is not None namespace = cast(str, explicit_namespace or sender_namespace) - self._process_seen_namespace(namespace) + await self._process_seen_namespace(namespace) self._message_queue.append( PublishMessageEnvelope( @@ -238,17 +240,17 @@ class SingleThreadedAgentRuntime(AgentRuntime): ) ) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: state: Dict[str, Dict[str, Any]] = {} for agent_id in self._instantiated_agents: - state[str(agent_id)] = dict(self._get_agent(agent_id).save_state()) + state[str(agent_id)] = dict((await self._get_agent(agent_id)).save_state()) return state - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: for agent_id_str in state: agent_id = AgentId.from_str(agent_id_str) if agent_id.name in self._known_agent_names: - self._get_agent(agent_id).load_state(state[str(agent_id)]) + (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: recipient = message_envelope.recipient @@ -269,7 +271,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): # delivery_stage=DeliveryStage.DELIVER, # ) # ) - recipient_agent = self._get_agent(recipient) + recipient_agent = await self._get_agent(recipient) response = await recipient_agent.on_message( message_envelope.message, cancellation_token=message_envelope.cancellation_token, @@ -297,7 +299,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name: continue - sender_agent = self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None + sender_agent = ( + await self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None + ) sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown" logger.info( f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}" @@ -312,7 +316,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): # ) # ) - agent = self._get_agent(agent_id) + agent = await self._get_agent(agent_id) future = agent.on_message( message_envelope.message, cancellation_token=message_envelope.cancellation_token, @@ -430,19 +434,19 @@ class SingleThreadedAgentRuntime(AgentRuntime): def start(self) -> RunContext: return RunContext(self) - def agent_metadata(self, agent: AgentId) -> AgentMetadata: - return self._get_agent(agent).metadata + async def agent_metadata(self, agent: AgentId) -> AgentMetadata: + return (await self._get_agent(agent)).metadata - def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: - return self._get_agent(agent).save_state() + async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: + return (await self._get_agent(agent)).save_state() - def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: - self._get_agent(agent).load_state(state) + async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: + (await self._get_agent(agent)).load_state(state) - def register( + async def register( self, name: str, - agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], ) -> None: if name in self._agent_factories: raise ValueError(f"Agent with name {name} already exists.") @@ -450,28 +454,30 @@ class SingleThreadedAgentRuntime(AgentRuntime): # For all already prepared namespaces we need to prepare this agent for namespace in self._known_namespaces: - self._get_agent(AgentId(name=name, namespace=namespace)) + await self._get_agent(AgentId(name=name, namespace=namespace)) - def _invoke_agent_factory( - self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId + async def _invoke_agent_factory( + self, + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], + agent_id: AgentId, ) -> T: - token = agent_instantiation_context.set((self, agent_id)) + with agent_instantiation_context((self, agent_id)): + if len(inspect.signature(agent_factory).parameters) == 0: + factory_one = cast(Callable[[], T], agent_factory) + agent = factory_one() + elif len(inspect.signature(agent_factory).parameters) == 2: + factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) + agent = factory_two(self, agent_id) + else: + raise ValueError("Agent factory must take 0 or 2 arguments.") - if len(inspect.signature(agent_factory).parameters) == 0: - factory_one = cast(Callable[[], T], agent_factory) - agent = factory_one() - elif len(inspect.signature(agent_factory).parameters) == 2: - factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) - agent = factory_two(self, agent_id) - else: - raise ValueError("Agent factory must take 0 or 2 arguments.") + if inspect.isawaitable(agent): + return cast(T, await agent) - agent_instantiation_context.reset(token) + return agent - return agent - - def _get_agent(self, agent_id: AgentId) -> Agent: - self._process_seen_namespace(agent_id.namespace) + async def _get_agent(self, agent_id: AgentId) -> Agent: + await self._process_seen_namespace(agent_id.namespace) if agent_id in self._instantiated_agents: return self._instantiated_agents[agent_id] @@ -480,25 +486,25 @@ class SingleThreadedAgentRuntime(AgentRuntime): agent_factory = self._agent_factories[agent_id.name] - agent = self._invoke_agent_factory(agent_factory, agent_id) + agent = await self._invoke_agent_factory(agent_factory, agent_id) for message_type in agent.metadata["subscriptions"]: self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id) self._instantiated_agents[agent_id] = agent return agent - def get(self, name: str, *, namespace: str = "default") -> AgentId: - return self._get_agent(AgentId(name=name, namespace=namespace)).id + async def get(self, name: str, *, namespace: str = "default") -> AgentId: + return (await self._get_agent(AgentId(name=name, namespace=namespace))).id - def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: - id = self.get(name, namespace=namespace) + async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: + id = await self.get(name, namespace=namespace) return AgentProxy(id, self) # Hydrate the agent instances in a namespace. The primary reason for this is # to ensure message type subscriptions are set up. - def _process_seen_namespace(self, namespace: str) -> None: + async def _process_seen_namespace(self, namespace: str) -> None: if namespace in self._known_namespaces: return self._known_namespaces.add(namespace) for name in self._known_agent_names: - self._get_agent(AgentId(name=name, namespace=namespace)) + await self._get_agent(AgentId(name=name, namespace=namespace)) diff --git a/python/src/agnext/components/_closure_agent.py b/python/src/agnext/components/_closure_agent.py index b2e9adec8..e0340a56e 100644 --- a/python/src/agnext/components/_closure_agent.py +++ b/python/src/agnext/components/_closure_agent.py @@ -4,7 +4,7 @@ from typing import Any, Awaitable, Callable, Mapping, Sequence, TypeVar, get_typ from ..core._agent import Agent from ..core._agent_id import AgentId from ..core._agent_metadata import AgentMetadata -from ..core._agent_runtime import AgentRuntime, agent_instantiation_context +from ..core._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime from ..core._cancellation_token import CancellationToken from ..core._serialization import MESSAGE_TYPE_REGISTRY from ..core.exceptions import CantHandleException @@ -46,7 +46,7 @@ class ClosureAgent(Agent): self, description: str, closure: Callable[[AgentRuntime, AgentId, T, CancellationToken], Awaitable[Any]] ) -> None: try: - runtime, id = agent_instantiation_context.get() + runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get() except LookupError as e: raise RuntimeError( "ClosureAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." diff --git a/python/src/agnext/core/__init__.py b/python/src/agnext/core/__init__.py index 98265e618..c89c10565 100644 --- a/python/src/agnext/core/__init__.py +++ b/python/src/agnext/core/__init__.py @@ -7,7 +7,7 @@ from ._agent_id import AgentId from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_proxy import AgentProxy -from ._agent_runtime import AgentRuntime, agent_instantiation_context +from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime, agent_instantiation_context from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken from ._serialization import MESSAGE_TYPE_REGISTRY, TypeDeserializer, TypeSerializer @@ -22,6 +22,7 @@ __all__ = [ "CancellationToken", "AgentChildren", "agent_instantiation_context", + "AGENT_INSTANTIATION_CONTEXT_VAR", "MESSAGE_TYPE_REGISTRY", "TypeSerializer", "TypeDeserializer", diff --git a/python/src/agnext/core/_agent_proxy.py b/python/src/agnext/core/_agent_proxy.py index 854376da8..f3eb70f28 100644 --- a/python/src/agnext/core/_agent_proxy.py +++ b/python/src/agnext/core/_agent_proxy.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING, Any, Awaitable, Mapping from ._agent_id import AgentId from ._agent_metadata import AgentMetadata @@ -21,7 +21,7 @@ class AgentProxy: return self._agent @property - def metadata(self) -> AgentMetadata: + def metadata(self) -> Awaitable[AgentMetadata]: """Metadata of the agent.""" return self._runtime.agent_metadata(self._agent) @@ -39,14 +39,14 @@ class AgentProxy: cancellation_token=cancellation_token, ) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: """Save the state of the agent. The result must be JSON serializable.""" - return self._runtime.agent_save_state(self._agent) + return await self._runtime.agent_save_state(self._agent) - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: """Load in the state of the agent obtained from `save_state`. Args: state (Mapping[str, Any]): State of the agent. Must be JSON serializable. """ - self._runtime.agent_load_state(self._agent, state) + await self._runtime.agent_load_state(self._agent, state) diff --git a/python/src/agnext/core/_agent_runtime.py b/python/src/agnext/core/_agent_runtime.py index a1054b83b..1d423180e 100644 --- a/python/src/agnext/core/_agent_runtime.py +++ b/python/src/agnext/core/_agent_runtime.py @@ -1,7 +1,8 @@ from __future__ import annotations +from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Callable, Mapping, Protocol, TypeVar, overload, runtime_checkable +from typing import Any, Awaitable, Callable, Generator, Mapping, Protocol, TypeVar, overload, runtime_checkable from ._agent import Agent from ._agent_id import AgentId @@ -13,7 +14,18 @@ from ._cancellation_token import CancellationToken T = TypeVar("T", bound=Agent) -agent_instantiation_context: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar("agent_instantiation_context") +AGENT_INSTANTIATION_CONTEXT_VAR: ContextVar[tuple[AgentRuntime, AgentId]] = ContextVar( + "AGENT_INSTANTIATION_CONTEXT_VAR" +) + + +@contextmanager +def agent_instantiation_context(ctx: tuple[AgentRuntime, AgentId]) -> Generator[None, Any, None]: + token = AGENT_INSTANTIATION_CONTEXT_VAR.set(ctx) + try: + yield + finally: + AGENT_INSTANTIATION_CONTEXT_VAR.reset(token) @runtime_checkable @@ -68,23 +80,23 @@ class AgentRuntime(Protocol): """ @overload - def register( + async def register( self, name: str, - agent_factory: Callable[[], T], + agent_factory: Callable[[], T | Awaitable[T]], ) -> None: ... @overload - def register( + async def register( self, name: str, - agent_factory: Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]], ) -> None: ... - def register( + async def register( self, name: str, - agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], ) -> None: """Register an agent factory with the runtime associated with a specific name. The name must be unique. @@ -110,7 +122,7 @@ class AgentRuntime(Protocol): ... - def get(self, name: str, *, namespace: str = "default") -> AgentId: + async def get(self, name: str, *, namespace: str = "default") -> AgentId: """Get an agent by name and namespace. Args: @@ -122,7 +134,7 @@ class AgentRuntime(Protocol): """ ... - def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: + async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: """Get a proxy for an agent by name and namespace. Args: @@ -135,27 +147,27 @@ class AgentRuntime(Protocol): ... @overload - def register_and_get( + async def register_and_get( self, name: str, - agent_factory: Callable[[], T], + agent_factory: Callable[[], T | Awaitable[T]], *, namespace: str = "default", ) -> AgentId: ... @overload - def register_and_get( + async def register_and_get( self, name: str, - agent_factory: Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]], *, namespace: str = "default", ) -> AgentId: ... - def register_and_get( + async def register_and_get( self, name: str, - agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], *, namespace: str = "default", ) -> AgentId: @@ -169,31 +181,31 @@ class AgentRuntime(Protocol): Returns: AgentId: The agent id. """ - self.register(name, agent_factory) - return self.get(name, namespace=namespace) + await self.register(name, agent_factory) + return await self.get(name, namespace=namespace) @overload - def register_and_get_proxy( + async def register_and_get_proxy( self, name: str, - agent_factory: Callable[[], T], + agent_factory: Callable[[], T | Awaitable[T]], *, namespace: str = "default", ) -> AgentProxy: ... @overload - def register_and_get_proxy( + async def register_and_get_proxy( self, name: str, - agent_factory: Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[AgentRuntime, AgentId], T | Awaitable[T]], *, namespace: str = "default", ) -> AgentProxy: ... - def register_and_get_proxy( + async def register_and_get_proxy( self, name: str, - agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], *, namespace: str = "default", ) -> AgentProxy: @@ -207,10 +219,10 @@ class AgentRuntime(Protocol): Returns: AgentProxy: The agent proxy. """ - self.register(name, agent_factory) - return self.get_proxy(name, namespace=namespace) + await self.register(name, agent_factory) + return await self.get_proxy(name, namespace=namespace) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: """Save the state of the entire runtime, including all hosted agents. The only way to restore the state is to pass it to :meth:`load_state`. The structure of the state is implementation defined and can be any JSON serializable object. @@ -220,7 +232,7 @@ class AgentRuntime(Protocol): """ ... - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: """Load the state of the entire runtime, including all hosted agents. The state should be the same as the one returned by :meth:`save_state`. Args: @@ -228,7 +240,7 @@ class AgentRuntime(Protocol): """ ... - def agent_metadata(self, agent: AgentId) -> AgentMetadata: + async def agent_metadata(self, agent: AgentId) -> AgentMetadata: """Get the metadata for an agent. Args: @@ -239,7 +251,7 @@ class AgentRuntime(Protocol): """ ... - def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: + async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: """Save the state of a single agent. The structure of the state is implementation defined and can be any JSON serializable object. @@ -252,7 +264,7 @@ class AgentRuntime(Protocol): """ ... - def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: + async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: """Load the state of a single agent. Args: diff --git a/python/src/agnext/core/_base_agent.py b/python/src/agnext/core/_base_agent.py index fdc6c681f..058002892 100644 --- a/python/src/agnext/core/_base_agent.py +++ b/python/src/agnext/core/_base_agent.py @@ -5,7 +5,7 @@ from typing import Any, Mapping, Sequence from ._agent import Agent from ._agent_id import AgentId from ._agent_metadata import AgentMetadata -from ._agent_runtime import AgentRuntime, agent_instantiation_context +from ._agent_runtime import AGENT_INSTANTIATION_CONTEXT_VAR, AgentRuntime from ._cancellation_token import CancellationToken @@ -22,7 +22,7 @@ class BaseAgent(ABC, Agent): def __init__(self, description: str, subscriptions: Sequence[str]) -> None: try: - runtime, id = agent_instantiation_context.get() + runtime, id = AGENT_INSTANTIATION_CONTEXT_VAR.get() except LookupError as e: raise RuntimeError( "BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated." diff --git a/python/src/agnext/worker/worker_runtime.py b/python/src/agnext/worker/worker_runtime.py index e00e5fe5a..e80102ce4 100644 --- a/python/src/agnext/worker/worker_runtime.py +++ b/python/src/agnext/worker/worker_runtime.py @@ -12,6 +12,7 @@ from typing import ( Any, AsyncIterable, AsyncIterator, + Awaitable, Callable, ClassVar, DefaultDict, @@ -188,7 +189,9 @@ class WorkerAgentRuntime(AgentRuntime): self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] # (namespace, type) -> List[AgentId] self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set) - self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {} + self._agent_factories: Dict[ + str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] + ] = {} # If empty, then all namespaces are valid for that agent type self._valid_namespaces: Dict[str, Sequence[str]] = {} self._instantiated_agents: Dict[AgentId, Agent] = {} @@ -249,7 +252,7 @@ class WorkerAgentRuntime(AgentRuntime): (namespace, MESSAGE_TYPE_REGISTRY.type_name(message)) ]: logger.info("Sending message to %s", agent_id) - agent = self._get_agent(agent_id) + agent = await self._get_agent(agent_id) try: await agent.on_message(message, CancellationToken()) logger.info("%s handled event %s", agent_id, message) @@ -321,7 +324,7 @@ class WorkerAgentRuntime(AgentRuntime): assert explicit_namespace is not None or sender_namespace is not None actual_namespace = cast(str, explicit_namespace or sender_namespace) - self._process_seen_namespace(actual_namespace) + await self._process_seen_namespace(actual_namespace) message_type = MESSAGE_TYPE_REGISTRY.type_name(message) serialized_message = MESSAGE_TYPE_REGISTRY.serialize(message, type_name=message_type) message = Message(event=Event(namespace=actual_namespace, type=message_type, data=serialized_message)) @@ -332,25 +335,25 @@ class WorkerAgentRuntime(AgentRuntime): await asyncio.create_task(write_message()) - def save_state(self) -> Mapping[str, Any]: + async def save_state(self) -> Mapping[str, Any]: raise NotImplementedError("Saving state is not yet implemented.") - def load_state(self, state: Mapping[str, Any]) -> None: + async def load_state(self, state: Mapping[str, Any]) -> None: raise NotImplementedError("Loading state is not yet implemented.") - def agent_metadata(self, agent: AgentId) -> AgentMetadata: + async def agent_metadata(self, agent: AgentId) -> AgentMetadata: raise NotImplementedError("Agent metadata is not yet implemented.") - def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: + async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: raise NotImplementedError("Agent save_state is not yet implemented.") - def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: + async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: raise NotImplementedError("Agent load_state is not yet implemented.") - def register( + async def register( self, name: str, - agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], ) -> None: if name in self._agent_factories: raise ValueError(f"Agent with name {name} already exists.") @@ -358,29 +361,32 @@ class WorkerAgentRuntime(AgentRuntime): # For all already prepared namespaces we need to prepare this agent for namespace in self._known_namespaces: - self._get_agent(AgentId(name=name, namespace=namespace)) + await self._get_agent(AgentId(name=name, namespace=namespace)) - # TODO do we need to convert register to async? - asyncio.create_task(self.send_register_agent_type(name)) + await self.send_register_agent_type(name) - def _invoke_agent_factory( + async def _invoke_agent_factory( self, - agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], agent_id: AgentId, ) -> T: - if len(inspect.signature(agent_factory).parameters) == 0: - factory_one = cast(Callable[[], T], agent_factory) - agent = factory_one() - elif len(inspect.signature(agent_factory).parameters) == 2: - factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) - agent = factory_two(self, agent_id) - else: - raise ValueError("Agent factory must take 0 or 2 arguments.") + with agent_instantiation_context((self, agent_id)): + if len(inspect.signature(agent_factory).parameters) == 0: + factory_one = cast(Callable[[], T], agent_factory) + agent = factory_one() + elif len(inspect.signature(agent_factory).parameters) == 2: + factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) + agent = factory_two(self, agent_id) + else: + raise ValueError("Agent factory must take 0 or 2 arguments.") + + if inspect.isawaitable(agent): + return cast(T, await agent) return agent - def _get_agent(self, agent_id: AgentId) -> Agent: - self._process_seen_namespace(agent_id.namespace) + async def _get_agent(self, agent_id: AgentId) -> Agent: + await self._process_seen_namespace(agent_id.namespace) if agent_id in self._instantiated_agents: return self._instantiated_agents[agent_id] @@ -389,9 +395,7 @@ class WorkerAgentRuntime(AgentRuntime): agent_factory = self._agent_factories[agent_id.name] - token = agent_instantiation_context.set((self, agent_id)) - agent = self._invoke_agent_factory(agent_factory, agent_id) - agent_instantiation_context.reset(token) + agent = await self._invoke_agent_factory(agent_factory, agent_id) for message_type in agent.metadata["subscriptions"]: self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id) @@ -399,19 +403,19 @@ class WorkerAgentRuntime(AgentRuntime): self._instantiated_agents[agent_id] = agent return agent - def get(self, name: str, *, namespace: str = "default") -> AgentId: - return self._get_agent(AgentId(name=name, namespace=namespace)).id + async def get(self, name: str, *, namespace: str = "default") -> AgentId: + return (await self._get_agent(AgentId(name=name, namespace=namespace))).id - def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: - id = self.get(name, namespace=namespace) + async def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: + id = await self.get(name, namespace=namespace) return AgentProxy(id, self) # Hydrate the agent instances in a namespace. The primary reason for this is # to ensure message type subscriptions are set up. - def _process_seen_namespace(self, namespace: str) -> None: + async def _process_seen_namespace(self, namespace: str) -> None: if namespace in self._known_namespaces: return self._known_namespaces.add(namespace) for name in self._known_agent_names: - self._get_agent(AgentId(name=name, namespace=namespace)) + await self._get_agent(AgentId(name=name, namespace=namespace)) diff --git a/python/teams/team-one/examples/example.py b/python/teams/team-one/examples/example.py index bdd0ca7ff..5bab59eca 100644 --- a/python/teams/team-one/examples/example.py +++ b/python/teams/team-one/examples/example.py @@ -15,19 +15,19 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register agents. - coder = runtime.register_and_get_proxy( + coder = await runtime.register_and_get_proxy( "Coder", lambda: Coder(model_client=create_completion_client_from_env()), ) - executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code")) + executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code")) - user_proxy = runtime.register_and_get_proxy( + user_proxy = await runtime.register_and_get_proxy( "UserProxy", lambda: UserProxy(description="The current user interacting with you."), ) - runtime.register( + await runtime.register( "orchestrator", lambda: LedgerOrchestrator( model_client=create_completion_client_from_env(), agents=[coder, executor, user_proxy] diff --git a/python/teams/team-one/examples/example_coder.py b/python/teams/team-one/examples/example_coder.py index 42778469b..d45db0999 100644 --- a/python/teams/team-one/examples/example_coder.py +++ b/python/teams/team-one/examples/example_coder.py @@ -15,19 +15,19 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() # Register agents. - coder = runtime.register_and_get_proxy( + coder = await runtime.register_and_get_proxy( "Coder", lambda: Coder(model_client=create_completion_client_from_env()), ) - executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code")) + executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code")) - user_proxy = runtime.register_and_get_proxy( + user_proxy = await runtime.register_and_get_proxy( "UserProxy", lambda: UserProxy(), ) - runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy])) + await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy])) run_context = runtime.start() await runtime.send_message(RequestReplyMessage(), user_proxy.id) diff --git a/python/teams/team-one/examples/example_file_surfer.py b/python/teams/team-one/examples/example_file_surfer.py index d26f6a927..c5d00c8cc 100644 --- a/python/teams/team-one/examples/example_file_surfer.py +++ b/python/teams/team-one/examples/example_file_surfer.py @@ -18,16 +18,16 @@ async def main() -> None: client = create_completion_client_from_env() # Register agents. - file_surfer = runtime.register_and_get_proxy( + file_surfer = await runtime.register_and_get_proxy( "file_surfer", lambda: FileSurfer(model_client=client), ) - user_proxy = runtime.register_and_get_proxy( + user_proxy = await runtime.register_and_get_proxy( "UserProxy", lambda: UserProxy(), ) - runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy])) + await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy])) run_context = runtime.start() await runtime.send_message(RequestReplyMessage(), user_proxy.id) diff --git a/python/teams/team-one/examples/example_reflexagents.py b/python/teams/team-one/examples/example_reflexagents.py index e0571492e..772c6ee84 100644 --- a/python/teams/team-one/examples/example_reflexagents.py +++ b/python/teams/team-one/examples/example_reflexagents.py @@ -13,10 +13,10 @@ from team_one.utils import LogHandler async def main() -> None: runtime = SingleThreadedAgentRuntime() - fake1 = runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent")) - fake2 = runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent")) - fake3 = runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent")) - runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3])) + fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent")) + fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent")) + fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent")) + await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3])) task_message = UserMessage(content="Test Message", source="User") run_context = runtime.start() diff --git a/python/teams/team-one/examples/example_userproxy.py b/python/teams/team-one/examples/example_userproxy.py index 304fac2af..9586f6b1a 100644 --- a/python/teams/team-one/examples/example_userproxy.py +++ b/python/teams/team-one/examples/example_userproxy.py @@ -19,16 +19,16 @@ async def main() -> None: client = create_completion_client_from_env() # Register agents. - coder = runtime.register_and_get_proxy( + coder = await runtime.register_and_get_proxy( "Coder", lambda: Coder(model_client=client), ) - user_proxy = runtime.register_and_get_proxy( + user_proxy = await runtime.register_and_get_proxy( "UserProxy", lambda: UserProxy(), ) - runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy])) + await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy])) run_context = runtime.start() await runtime.send_message(RequestReplyMessage(), user_proxy.id) diff --git a/python/teams/team-one/examples/example_websurfer.py b/python/teams/team-one/examples/example_websurfer.py index ff137a421..da75ece65 100644 --- a/python/teams/team-one/examples/example_websurfer.py +++ b/python/teams/team-one/examples/example_websurfer.py @@ -21,17 +21,17 @@ async def main() -> None: client = create_completion_client_from_env() # Register agents. - web_surfer = runtime.register_and_get_proxy( + web_surfer = await runtime.register_and_get_proxy( "WebSurfer", lambda: MultimodalWebSurfer(), ) - user_proxy = runtime.register_and_get_proxy( + user_proxy = await runtime.register_and_get_proxy( "UserProxy", lambda: UserProxy(), ) - runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy])) + await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy])) run_context = runtime.start() diff --git a/python/teams/team-one/pyproject.toml b/python/teams/team-one/pyproject.toml index 21a6b0bf3..1a606ae6f 100644 --- a/python/teams/team-one/pyproject.toml +++ b/python/teams/team-one/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "youtube-transcript-api", "SpeechRecognition", "pathvalidate", - "playwright" + "playwright", ] [tool.hatch.envs.default] @@ -45,7 +45,8 @@ dependencies = [ "aiofiles", "types-aiofiles", "types-requests", - "azure-identity" + "types-pillow", + "azure-identity", ] [tool.hatch.envs.default.extra-scripts] @@ -71,7 +72,13 @@ line-length = 120 fix = true exclude = ["build", "dist", "page_script.js"] target-version = "py310" -include = ["src/**", "examples/*.py"] +include = [ + "src/**", + "examples/*.py", + "../../benchmarks/HumanEval/Templates/TeamOne/scenario.py", + "../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py", + "../../benchmarks/GAIA/TeamOne/TwoAgents/scenario.py", +] [tool.ruff.format] docstring-code-format = true @@ -81,7 +88,11 @@ select = ["E", "F", "W", "B", "Q", "I", "ASYNC"] ignore = ["F401", "E501"] [tool.mypy] -files = ["src", "examples", "tests"] +files = [ + "src", + "tests", + "examples", +] strict = true python_version = "3.10" @@ -100,7 +111,14 @@ disallow_untyped_decorators = true disallow_any_unimported = true [tool.pyright] -include = ["src", "tests", "examples"] +include = [ + "src", + "tests", + "examples", + "../../benchmarks/HumanEval/Templates/TeamOne/scenario.py", + "../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py", + "../../benchmarks/GAIA/Templates/TeamOne/scenario.py", +] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false diff --git a/python/teams/team-one/src/team_one/agents/base_orchestrator.py b/python/teams/team-one/src/team_one/agents/base_orchestrator.py index f7e0c6d88..d8f903386 100644 --- a/python/teams/team-one/src/team_one/agents/base_orchestrator.py +++ b/python/teams/team-one/src/team_one/agents/base_orchestrator.py @@ -69,7 +69,7 @@ class BaseOrchestrator(TypeRoutedAgent): logger.info( OrchestrationEvent( source=f"{self.metadata['name']} (thought)", - message=f"Next speaker {next_agent.metadata['name']}" "", + message=f"Next speaker {(await next_agent.metadata)['name']}" "", ) ) diff --git a/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/set_of_mark.py b/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/set_of_mark.py index 892b85b8e..29e190c28 100644 --- a/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/set_of_mark.py +++ b/python/teams/team-one/src/team_one/agents/multimodal_web_surfer/set_of_mark.py @@ -68,7 +68,7 @@ def _draw_roi( luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11 text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255) - roi = [(rect["left"], rect["top"]), (rect["right"], rect["bottom"])] + roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"])) label_location = (rect["right"], rect["top"]) label_anchor = "rb" diff --git a/python/teams/team-one/src/team_one/agents/orchestrator.py b/python/teams/team-one/src/team_one/agents/orchestrator.py index ef77aaecb..7ca1d25e6 100644 --- a/python/teams/team-one/src/team_one/agents/orchestrator.py +++ b/python/teams/team-one/src/team_one/agents/orchestrator.py @@ -79,16 +79,16 @@ class LedgerOrchestrator(BaseOrchestrator): def _get_ledger_prompt(self, task: str, team: str, names: List[str]) -> str: return self._ledger_prompt.format(task=task, team=team, names=names) - def _get_team_description(self) -> str: + async def _get_team_description(self) -> str: team_description = "" for agent in self._agents: - name = agent.metadata["name"] - description = agent.metadata["description"] + name = (await agent.metadata)["name"] + description = (await agent.metadata)["description"] team_description += f"{name}: {description}\n" return team_description - def _get_team_names(self) -> List[str]: - return [agent.metadata["name"] for agent in self._agents] + async def _get_team_names(self) -> List[str]: + return [(await agent.metadata)["name"] for agent in self._agents] def _set_task_str(self, message: LLMMessage) -> None: if len(self._chat_history) == 1: @@ -112,7 +112,7 @@ class LedgerOrchestrator(BaseOrchestrator): return False async def _plan(self) -> str: - team_description = self._get_team_description() + team_description = await self._get_team_description() # 1. GATHER FACTS # create a closed book task and generate a response and update the chat history @@ -144,8 +144,8 @@ class LedgerOrchestrator(BaseOrchestrator): async def update_ledger(self) -> Dict[str, Any]: max_json_retries = 10 - team_description = self._get_team_description() - names = self._get_team_names() + team_description = await self._get_team_description() + names = await self._get_team_names() ledger_prompt = self._get_ledger_prompt(self.task_str, team_description, names) ledger_user_message = UserMessage(content=ledger_prompt, source=self.metadata["name"]) @@ -234,7 +234,7 @@ class LedgerOrchestrator(BaseOrchestrator): next_agent_name = ledger_dict["next_speaker"]["answer"] for agent in self._agents: - if agent.metadata["name"] == next_agent_name: + if (await agent.metadata)["name"] == next_agent_name: # broadcast a new message instruction = ledger_dict["instruction_or_question"]["answer"] user_message = UserMessage(content=instruction, source=self.metadata["name"]) diff --git a/python/tests/test_base_agent.py b/python/tests/test_base_agent.py index 81e364ced..41d414254 100644 --- a/python/tests/test_base_agent.py +++ b/python/tests/test_base_agent.py @@ -1,6 +1,6 @@ import pytest from pytest_mock import MockerFixture -from agnext.core import AgentRuntime, agent_instantiation_context, AgentId +from agnext.core import AgentRuntime, AGENT_INSTANTIATION_CONTEXT_VAR, AgentId from test_utils import NoopAgent @@ -11,7 +11,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None: runtime = mocker.Mock(spec=AgentRuntime) # Shows how to set the context for the agent instantiation in a test context - agent_instantiation_context.set((runtime, AgentId("name", "namespace"))) + AGENT_INSTANTIATION_CONTEXT_VAR.set((runtime, AgentId("name", "namespace"))) agent = NoopAgent() assert agent.runtime == runtime diff --git a/python/tests/test_cancellation.py b/python/tests/test_cancellation.py index 4357493ea..f383d8de7 100644 --- a/python/tests/test_cancellation.py +++ b/python/tests/test_cancellation.py @@ -57,7 +57,7 @@ class NestingLongRunningAgent(TypeRoutedAgent): async def test_cancellation_with_token() -> None: runtime = SingleThreadedAgentRuntime() - long_running = runtime.register_and_get("long_running", LongRunningAgent) + long_running = await runtime.register_and_get("long_running", LongRunningAgent) token = CancellationToken() response = asyncio.create_task(runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token)) assert not response.done() @@ -73,7 +73,7 @@ async def test_cancellation_with_token() -> None: await response assert response.done() - long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore + long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore assert long_running_agent.called assert long_running_agent.cancelled @@ -83,8 +83,8 @@ async def test_cancellation_with_token() -> None: async def test_nested_cancellation_only_outer_called() -> None: runtime = SingleThreadedAgentRuntime() - long_running = runtime.register_and_get("long_running", LongRunningAgent) - nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) + long_running = await runtime.register_and_get("long_running", LongRunningAgent) + nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) token = CancellationToken() response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token)) @@ -100,10 +100,10 @@ async def test_nested_cancellation_only_outer_called() -> None: await response assert response.done() - nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore + nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore assert nested_agent.called assert nested_agent.cancelled - long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore + long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore assert long_running_agent.called is False assert long_running_agent.cancelled is False @@ -111,8 +111,8 @@ async def test_nested_cancellation_only_outer_called() -> None: async def test_nested_cancellation_inner_called() -> None: runtime = SingleThreadedAgentRuntime() - long_running = runtime.register_and_get("long_running", LongRunningAgent ) - nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) + long_running = await runtime.register_and_get("long_running", LongRunningAgent ) + nested = await runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) token = CancellationToken() response = asyncio.create_task(runtime.send_message(MessageType(), nested, cancellation_token=token)) @@ -130,9 +130,9 @@ async def test_nested_cancellation_inner_called() -> None: await response assert response.done() - nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore + nested_agent: NestingLongRunningAgent = await runtime._get_agent(nested) # type: ignore assert nested_agent.called assert nested_agent.cancelled - long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore + long_running_agent: LongRunningAgent = await runtime._get_agent(long_running) # type: ignore assert long_running_agent.called assert long_running_agent.cancelled diff --git a/python/tests/test_closure_agent.py b/python/tests/test_closure_agent.py index 7baee6a7b..4543d8fe2 100644 --- a/python/tests/test_closure_agent.py +++ b/python/tests/test_closure_agent.py @@ -28,7 +28,7 @@ async def test_register_receives_publish() -> None: namespace = id.namespace await queue.put((namespace, message.content)) - runtime.register("name", lambda: ClosureAgent("My agent", log_message)) + await runtime.register("name", lambda: ClosureAgent("My agent", log_message)) run_context = runtime.start() await runtime.publish_message(Message("first message"), namespace="default") await runtime.publish_message(Message("second message"), namespace="default") diff --git a/python/tests/test_intervention.py b/python/tests/test_intervention.py index 687a2372b..395305b22 100644 --- a/python/tests/test_intervention.py +++ b/python/tests/test_intervention.py @@ -19,7 +19,7 @@ async def test_intervention_count_messages() -> None: handler = DebugInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - loopback = runtime.register_and_get("name", LoopbackAgent) + loopback = await runtime.register_and_get("name", LoopbackAgent) run_context = runtime.start() _response = await runtime.send_message(MessageType(), recipient=loopback) @@ -27,7 +27,7 @@ async def test_intervention_count_messages() -> None: await run_context.stop() assert handler.num_messages == 1 - loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore + loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore assert loopback_agent.num_calls == 1 @pytest.mark.asyncio @@ -40,7 +40,7 @@ async def test_intervention_drop_send() -> None: handler = DropSendInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - loopback = runtime.register_and_get("name", LoopbackAgent) + loopback = await runtime.register_and_get("name", LoopbackAgent) run_context = runtime.start() with pytest.raises(MessageDroppedException): @@ -48,7 +48,7 @@ async def test_intervention_drop_send() -> None: await run_context.stop() - loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore + loopback_agent: LoopbackAgent = await runtime._get_agent(loopback) # type: ignore assert loopback_agent.num_calls == 0 @@ -62,7 +62,7 @@ async def test_intervention_drop_response() -> None: handler = DropResponseInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - loopback = runtime.register_and_get("name", LoopbackAgent) + loopback = await runtime.register_and_get("name", LoopbackAgent) run_context = runtime.start() with pytest.raises(MessageDroppedException): @@ -84,7 +84,7 @@ async def test_intervention_raise_exception_on_send() -> None: handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - long_running = runtime.register_and_get("name", LoopbackAgent) + long_running = await runtime.register_and_get("name", LoopbackAgent) run_context = runtime.start() with pytest.raises(InterventionException): @@ -92,7 +92,7 @@ async def test_intervention_raise_exception_on_send() -> None: await run_context.stop() - long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore + long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore assert long_running_agent.num_calls == 0 @pytest.mark.asyncio @@ -108,12 +108,12 @@ async def test_intervention_raise_exception_on_respond() -> None: handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handler=handler) - long_running = runtime.register_and_get("name", LoopbackAgent) + long_running = await runtime.register_and_get("name", LoopbackAgent) run_context = runtime.start() with pytest.raises(InterventionException): _response = await runtime.send_message(MessageType(), recipient=long_running) await run_context.stop() - long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore + long_running_agent: LoopbackAgent = await runtime._get_agent(long_running) # type: ignore assert long_running_agent.num_calls == 1 diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index dca93a700..8cef7b16c 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -14,31 +14,31 @@ async def test_agent_names_must_be_unique() -> None: assert agent.id == id return agent - agent1 = runtime.register_and_get("name1", agent_factory) + agent1 = await runtime.register_and_get("name1", agent_factory) assert agent1 == AgentId("name1", "default") with pytest.raises(ValueError): - _agent1 = runtime.register_and_get("name1", NoopAgent) + _agent1 = await runtime.register_and_get("name1", NoopAgent) - _agent1 = runtime.register_and_get("name3", NoopAgent) + _agent1 = await runtime.register_and_get("name3", NoopAgent) @pytest.mark.asyncio async def test_register_receives_publish() -> None: runtime = SingleThreadedAgentRuntime() - runtime.register("name", LoopbackAgent) + await runtime.register("name", LoopbackAgent) run_context = runtime.start() await runtime.publish_message(MessageType(), namespace="default") await run_context.stop_when_idle() # Agent in default namespace should have received the message - long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore + long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name")) # type: ignore assert long_running_agent.num_calls == 1 # Agent in other namespace should not have received the message - other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore + other_long_running_agent: LoopbackAgent = await runtime._get_agent(await runtime.get("name", namespace="other")) # type: ignore assert other_long_running_agent.num_calls == 0 @@ -54,7 +54,7 @@ async def test_register_receives_publish_cascade() -> None: # Register agents for i in range(num_agents): - runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds)) + await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds)) run_context = runtime.start() @@ -67,5 +67,5 @@ async def test_register_receives_publish_cascade() -> None: # Check that each agent received the correct number of messages. for i in range(num_agents): - agent: CascadingAgent = runtime._get_agent(runtime.get(f"name{i}")) # type: ignore + agent: CascadingAgent = await runtime._get_agent(await runtime.get(f"name{i}")) # type: ignore assert agent.num_calls == total_num_calls_expected diff --git a/python/tests/test_state.py b/python/tests/test_state.py index 08c63a38d..b73e244bc 100644 --- a/python/tests/test_state.py +++ b/python/tests/test_state.py @@ -5,8 +5,8 @@ from agnext.application import SingleThreadedAgentRuntime from agnext.core import BaseAgent, CancellationToken -class StatefulAgent(BaseAgent): # type: ignore - def __init__(self) -> None: # type: ignore +class StatefulAgent(BaseAgent): + def __init__(self) -> None: super().__init__("A stateful agent", []) self.state = 0 @@ -14,7 +14,7 @@ class StatefulAgent(BaseAgent): # type: ignore def subscriptions(self) -> Sequence[type]: return [] - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> None: raise NotImplementedError def save_state(self) -> Mapping[str, Any]: @@ -28,8 +28,8 @@ class StatefulAgent(BaseAgent): # type: ignore async def test_agent_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - agent1_id = runtime.register_and_get("name1", StatefulAgent) - agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore + agent1_id = await runtime.register_and_get("name1", StatefulAgent) + agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore assert agent1.state == 0 agent1.state = 1 assert agent1.state == 1 @@ -46,19 +46,19 @@ async def test_agent_can_save_state() -> None: async def test_runtime_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - agent1_id = runtime.register_and_get("name1", StatefulAgent) - agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore + agent1_id = await runtime.register_and_get("name1", StatefulAgent) + agent1: StatefulAgent = await runtime._get_agent(agent1_id) # type: ignore assert agent1.state == 0 agent1.state = 1 assert agent1.state == 1 - runtime_state = runtime.save_state() + runtime_state = await runtime.save_state() runtime2 = SingleThreadedAgentRuntime() - agent2_id = runtime2.register_and_get("name1", StatefulAgent) - agent2: StatefulAgent = runtime2._get_agent(agent2_id) # type: ignore + agent2_id = await runtime2.register_and_get("name1", StatefulAgent) + agent2: StatefulAgent = await runtime2._get_agent(agent2_id) # type: ignore - runtime2.load_state(runtime_state) + await runtime2.load_state(runtime_state) assert agent2.state == 1