Agent factory can be async (#247)

This commit is contained in:
Jack Gerrits
2024-07-23 11:49:38 -07:00
committed by GitHub
parent 718fad6e0d
commit a52d3bab53
47 changed files with 352 additions and 299 deletions

View File

@@ -15,19 +15,19 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(description="The current user interacting with you."),
)
runtime.register(
await runtime.register(
"orchestrator",
lambda: LedgerOrchestrator(
model_client=create_completion_client_from_env(), agents=[coder, executor, user_proxy]

View File

@@ -15,19 +15,19 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
executor = runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)

View File

@@ -18,16 +18,16 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
file_surfer = runtime.register_and_get_proxy(
file_surfer = await runtime.register_and_get_proxy(
"file_surfer",
lambda: FileSurfer(model_client=client),
)
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)

View File

@@ -13,10 +13,10 @@ from team_one.utils import LogHandler
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
fake1 = runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake2 = runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake3 = runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
task_message = UserMessage(content="Test Message", source="User")
run_context = runtime.start()

View File

@@ -19,16 +19,16 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
coder = runtime.register_and_get_proxy(
coder = await runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=client),
)
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)

View File

@@ -21,17 +21,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
web_surfer = runtime.register_and_get_proxy(
web_surfer = await runtime.register_and_get_proxy(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
user_proxy = runtime.register_and_get_proxy(
user_proxy = await runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
run_context = runtime.start()

View File

@@ -32,7 +32,7 @@ dependencies = [
"youtube-transcript-api",
"SpeechRecognition",
"pathvalidate",
"playwright"
"playwright",
]
[tool.hatch.envs.default]
@@ -45,7 +45,8 @@ dependencies = [
"aiofiles",
"types-aiofiles",
"types-requests",
"azure-identity"
"types-pillow",
"azure-identity",
]
[tool.hatch.envs.default.extra-scripts]
@@ -71,7 +72,13 @@ line-length = 120
fix = true
exclude = ["build", "dist", "page_script.js"]
target-version = "py310"
include = ["src/**", "examples/*.py"]
include = [
"src/**",
"examples/*.py",
"../../benchmarks/HumanEval/Templates/TeamOne/scenario.py",
"../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py",
"../../benchmarks/GAIA/TeamOne/TwoAgents/scenario.py",
]
[tool.ruff.format]
docstring-code-format = true
@@ -81,7 +88,11 @@ select = ["E", "F", "W", "B", "Q", "I", "ASYNC"]
ignore = ["F401", "E501"]
[tool.mypy]
files = ["src", "examples", "tests"]
files = [
"src",
"tests",
"examples",
]
strict = true
python_version = "3.10"
@@ -100,7 +111,14 @@ disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.pyright]
include = ["src", "tests", "examples"]
include = [
"src",
"tests",
"examples",
"../../benchmarks/HumanEval/Templates/TeamOne/scenario.py",
"../../benchmarks/HumanEval/Templates/TwoAgents/scenario.py",
"../../benchmarks/GAIA/Templates/TeamOne/scenario.py",
]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false

View File

@@ -69,7 +69,7 @@ class BaseOrchestrator(TypeRoutedAgent):
logger.info(
OrchestrationEvent(
source=f"{self.metadata['name']} (thought)",
message=f"Next speaker {next_agent.metadata['name']}" "",
message=f"Next speaker {(await next_agent.metadata)['name']}" "",
)
)

View File

@@ -68,7 +68,7 @@ def _draw_roi(
luminance = color[0] * 0.3 + color[1] * 0.59 + color[2] * 0.11
text_color = (0, 0, 0, 255) if luminance > 90 else (255, 255, 255, 255)
roi = [(rect["left"], rect["top"]), (rect["right"], rect["bottom"])]
roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"]))
label_location = (rect["right"], rect["top"])
label_anchor = "rb"

View File

@@ -79,16 +79,16 @@ class LedgerOrchestrator(BaseOrchestrator):
def _get_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
return self._ledger_prompt.format(task=task, team=team, names=names)
def _get_team_description(self) -> str:
async def _get_team_description(self) -> str:
team_description = ""
for agent in self._agents:
name = agent.metadata["name"]
description = agent.metadata["description"]
name = (await agent.metadata)["name"]
description = (await agent.metadata)["description"]
team_description += f"{name}: {description}\n"
return team_description
def _get_team_names(self) -> List[str]:
return [agent.metadata["name"] for agent in self._agents]
async def _get_team_names(self) -> List[str]:
return [(await agent.metadata)["name"] for agent in self._agents]
def _set_task_str(self, message: LLMMessage) -> None:
if len(self._chat_history) == 1:
@@ -112,7 +112,7 @@ class LedgerOrchestrator(BaseOrchestrator):
return False
async def _plan(self) -> str:
team_description = self._get_team_description()
team_description = await self._get_team_description()
# 1. GATHER FACTS
# create a closed book task and generate a response and update the chat history
@@ -144,8 +144,8 @@ class LedgerOrchestrator(BaseOrchestrator):
async def update_ledger(self) -> Dict[str, Any]:
max_json_retries = 10
team_description = self._get_team_description()
names = self._get_team_names()
team_description = await self._get_team_description()
names = await self._get_team_names()
ledger_prompt = self._get_ledger_prompt(self.task_str, team_description, names)
ledger_user_message = UserMessage(content=ledger_prompt, source=self.metadata["name"])
@@ -234,7 +234,7 @@ class LedgerOrchestrator(BaseOrchestrator):
next_agent_name = ledger_dict["next_speaker"]["answer"]
for agent in self._agents:
if agent.metadata["name"] == next_agent_name:
if (await agent.metadata)["name"] == next_agent_name:
# broadcast a new message
instruction = ledger_dict["instruction_or_question"]["answer"]
user_message = UserMessage(content=instruction, source=self.metadata["name"])