Agent factory can be async (#247)

This commit is contained in:
Jack Gerrits
2024-07-23 11:49:38 -07:00
committed by GitHub
parent 718fad6e0d
commit a52d3bab53
47 changed files with 352 additions and 299 deletions

View File

@@ -22,10 +22,12 @@ async def select_speaker(memory: ChatMemory[Message], client: ChatCompletionClie
history = "\n".join(history_messages)
# Construct agent roles.
roles = "\n".join([f"{agent.metadata['name']}: {agent.metadata['description']}".strip() for agent in agents])
roles = "\n".join(
[f"{(await agent.metadata)['name']}: {(await agent.metadata)['description']}".strip() for agent in agents]
)
# Construct agent list.
participants = str([agent.metadata["name"] for agent in agents])
participants = str([(await agent.metadata)["name"] for agent in agents])
# Select the next speaker.
select_speaker_prompt = f"""You are in a role play game. The following roles are available:
@@ -39,16 +41,22 @@ Read the above conversation. Then select the next role from {participants} to pl
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
response = await client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
mentions = mentioned_agents(response.content, agents)
mentions = await mentioned_agents(response.content, agents)
if len(mentions) != 1:
raise ValueError(f"Expected exactly one agent to be mentioned, but got {mentions}")
agent_name = list(mentions.keys())[0]
agent_index = next((i for i, agent in enumerate(agents) if agent.metadata["name"] == agent_name), None)
# Get the index of the selected agent by name
agent_index = 0
for i, agent in enumerate(agents):
if (await agent.metadata)["name"] == agent_name:
agent_index = i
break
assert agent_index is not None
return agent_index
def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
async def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
"""Counts the number of times each agent is mentioned in the provided message content.
Agent names will match under any of the following conditions (all case-sensitive):
- Exact name match
@@ -66,7 +74,7 @@ def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str
for agent in agents:
# Finds agent mentions, taking word boundaries into account,
# accommodates escaping underscores and underscores as spaces
name = agent.metadata["name"]
name = (await agent.metadata)["name"]
regex = (
r"(?<=\W)("
+ re.escape(name)

View File

@@ -170,7 +170,10 @@ Some additional points to consider:
# A reusable description of the team.
team = "\n".join(
[agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists]
[
agent.name + ": " + (await self.runtime.agent_metadata(agent))["description"]
for agent in self._specialists
]
)
names = ", ".join([agent.name for agent in self._specialists])