mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Move tools to agent in agentchat; refactored logging to support tool events (#3665)
* Move tool to agent; refactor logging in agentchat * Update notebook
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from ._base_chat_agent import (
|
||||
BaseChatAgent,
|
||||
BaseToolUseChatAgent,
|
||||
ChatMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
@@ -13,6 +14,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
"BaseToolUseChatAgent",
|
||||
"ChatMessage",
|
||||
"TextMessage",
|
||||
"MultiModalMessage",
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Sequence
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import FunctionCall, Image
|
||||
from autogen_core.components.models import FunctionExecutionResult
|
||||
from autogen_core.components.tools import Tool
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -49,7 +50,7 @@ class StopMessage(BaseMessage):
|
||||
"""The content for the stop message."""
|
||||
|
||||
|
||||
ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
|
||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
|
||||
"""A message used by agents in a team."""
|
||||
|
||||
|
||||
@@ -79,3 +80,21 @@ class BaseChatAgent(ABC):
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
"""Handle incoming messages and return a response message."""
|
||||
...
|
||||
|
||||
|
||||
class BaseToolUseChatAgent(BaseChatAgent):
|
||||
"""Base class for a chat agent that can use tools.
|
||||
|
||||
Subclass this base class to create an agent class that uses tools by returning
|
||||
ToolCallMessage message from the :meth:`on_messages` method and receiving
|
||||
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
|
||||
super().__init__(name, description)
|
||||
self._registered_tools = registered_tools
|
||||
|
||||
@property
|
||||
def registered_tools(self) -> List[Tool]:
|
||||
"""The list of tools that the agent can use."""
|
||||
return self._registered_tools
|
||||
|
||||
@@ -10,10 +10,10 @@ from autogen_core.components.models import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import ToolSchema
|
||||
from autogen_core.components.tools import Tool
|
||||
|
||||
from ._base_chat_agent import (
|
||||
BaseChatAgent,
|
||||
BaseToolUseChatAgent,
|
||||
ChatMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
@@ -23,7 +23,7 @@ from ._base_chat_agent import (
|
||||
)
|
||||
|
||||
|
||||
class ToolUseAssistantAgent(BaseChatAgent):
|
||||
class ToolUseAssistantAgent(BaseToolUseChatAgent):
|
||||
"""An agent that provides assistance with tool use.
|
||||
|
||||
It responds with a StopMessage when 'terminate' is detected in the response.
|
||||
@@ -33,15 +33,15 @@ class ToolUseAssistantAgent(BaseChatAgent):
|
||||
self,
|
||||
name: str,
|
||||
model_client: ChatCompletionClient,
|
||||
tool_schema: List[ToolSchema],
|
||||
registered_tools: List[Tool],
|
||||
*,
|
||||
description: str = "An agent that provides assistance with ability to use tools.",
|
||||
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply 'TERMINATE' in the end when the task is completed.",
|
||||
):
|
||||
super().__init__(name=name, description=description)
|
||||
super().__init__(name=name, description=description, registered_tools=registered_tools)
|
||||
self._model_client = model_client
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
self._tool_schema = tool_schema
|
||||
self._tool_schema = [tool.schema for tool in registered_tools]
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from typing import Optional
|
||||
from autogen_core.base import AgentId
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents import ChatMessage
|
||||
from ..agents import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
|
||||
|
||||
class ContentPublishEvent(BaseModel):
|
||||
@@ -11,9 +10,13 @@ class ContentPublishEvent(BaseModel):
|
||||
content of the event.
|
||||
"""
|
||||
|
||||
agent_message: ChatMessage
|
||||
agent_message: TextMessage | MultiModalMessage | StopMessage
|
||||
"""The message published by the agent."""
|
||||
source: Optional[str] = None
|
||||
|
||||
source: AgentId | None = None
|
||||
"""The agent ID that published the message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ContentRequestEvent(BaseModel):
|
||||
@@ -22,3 +25,27 @@ class ContentRequestEvent(BaseModel):
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class ToolCallEvent(BaseModel):
|
||||
"""An event produced when requesting a tool call."""
|
||||
|
||||
agent_message: ToolCallMessage
|
||||
"""The tool call message."""
|
||||
|
||||
source: AgentId
|
||||
"""The sender of the tool call message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolCallResultEvent(BaseModel):
|
||||
"""An event produced when a tool call is completed."""
|
||||
|
||||
agent_message: ToolCallResultMessage
|
||||
"""The tool call result message."""
|
||||
|
||||
source: AgentId
|
||||
"""The sender of the tool call result message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@@ -3,13 +3,14 @@ import logging
|
||||
import sys
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Sequence, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from autogen_core.base import AgentId
|
||||
from autogen_core.components import FunctionCall, Image
|
||||
from autogen_core.components.models import FunctionExecutionResult
|
||||
|
||||
from ..agents import ChatMessage, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
from ._events import ContentPublishEvent
|
||||
from ._events import ContentPublishEvent, ToolCallEvent, ToolCallResultEvent
|
||||
|
||||
EVENT_LOGGER_NAME = "autogen_agentchat.events"
|
||||
ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[FunctionExecutionResult]]
|
||||
@@ -17,7 +18,8 @@ ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[Funct
|
||||
|
||||
class BaseLogHandler(logging.Handler):
|
||||
def serialize_content(
|
||||
self, content: Union[ContentType, Sequence[ChatMessage], ChatMessage]
|
||||
self,
|
||||
content: Union[ContentType, ChatMessage],
|
||||
) -> Union[List[Any], Dict[str, Any], str]:
|
||||
if isinstance(content, (str, list)):
|
||||
return content
|
||||
@@ -41,19 +43,35 @@ class BaseLogHandler(logging.Handler):
|
||||
|
||||
|
||||
class ConsoleLogHandler(BaseLogHandler):
|
||||
def _format_message(
|
||||
self,
|
||||
*,
|
||||
source_agent_id: AgentId | None,
|
||||
message: ChatMessage,
|
||||
timestamp: str,
|
||||
) -> str:
|
||||
body = f"{self.serialize_content(message.content)}\nFrom: {message.source}"
|
||||
if source_agent_id is None:
|
||||
console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}]:\033[0m\n" f"\n{body}"
|
||||
else:
|
||||
# Display the source agent type rather than agent ID for better readability.
|
||||
# Also in AgentChat the agent type is unique for each agent.
|
||||
console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}], {source_agent_id.type}:\033[0m\n" f"\n{body}"
|
||||
return console_message
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, ContentPublishEvent):
|
||||
console_message = (
|
||||
f"\n{'-'*75} \n"
|
||||
f"\033[91m[{ts}], {record.msg.agent_message.source}:\033[0m\n"
|
||||
f"\n{self.serialize_content(record.msg.agent_message.content)}"
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
|
||||
sys.stdout.write(
|
||||
self._format_message(
|
||||
source_agent_id=record.msg.source,
|
||||
message=record.msg.agent_message,
|
||||
timestamp=ts,
|
||||
)
|
||||
sys.stdout.write(console_message)
|
||||
sys.stdout.flush()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
)
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
raise ValueError(f"Unexpected log record: {record.msg}")
|
||||
|
||||
|
||||
class FileLogHandler(BaseLogHandler):
|
||||
@@ -62,32 +80,37 @@ class FileLogHandler(BaseLogHandler):
|
||||
self.filename = filename
|
||||
self.file_handler = logging.FileHandler(filename)
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, ContentPublishEvent):
|
||||
log_entry = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"source": record.msg.agent_message.source,
|
||||
"message": self.serialize_content(record.msg.agent_message.content),
|
||||
"type": "OrchestrationEvent",
|
||||
},
|
||||
default=self.json_serializer,
|
||||
)
|
||||
def _format_entry(self, *, source: AgentId | None, message: ChatMessage, timestamp: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"timestamp": timestamp,
|
||||
"source": source,
|
||||
"message": self.serialize_content(message),
|
||||
"type": "OrchestrationEvent",
|
||||
}
|
||||
|
||||
file_record = logging.LogRecord(
|
||||
name=record.name,
|
||||
level=record.levelno,
|
||||
pathname=record.pathname,
|
||||
lineno=record.lineno,
|
||||
msg=log_entry,
|
||||
args=(),
|
||||
exc_info=record.exc_info,
|
||||
)
|
||||
self.file_handler.emit(file_record)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
|
||||
log_entry = json.dumps(
|
||||
self._format_entry(
|
||||
source=record.msg.source,
|
||||
message=record.msg.agent_message,
|
||||
timestamp=ts,
|
||||
),
|
||||
default=self.json_serializer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected log record: {record.msg}")
|
||||
file_record = logging.LogRecord(
|
||||
name=record.name,
|
||||
level=record.levelno,
|
||||
pathname=record.pathname,
|
||||
lineno=record.lineno,
|
||||
msg=log_entry,
|
||||
args=(),
|
||||
exc_info=record.exc_info,
|
||||
)
|
||||
self.file_handler.emit(file_record)
|
||||
|
||||
def close(self) -> None:
|
||||
self.file_handler.close()
|
||||
|
||||
@@ -8,10 +8,12 @@ from autogen_core.components.models import FunctionExecutionResult
|
||||
from autogen_core.components.tool_agent import ToolException
|
||||
|
||||
from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
from .._events import ContentPublishEvent, ContentRequestEvent
|
||||
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
|
||||
from .._logging import EVENT_LOGGER_NAME
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class BaseChatAgentContainer(SequentialRoutedAgent):
|
||||
"""A core agent class that delegates message handling to an
|
||||
@@ -21,16 +23,15 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
|
||||
Args:
|
||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||
agent (BaseChatAgent): The agent to delegate message handling to.
|
||||
tool_agent_type (AgentType): The agent type of the tool agent to use for tool calls.
|
||||
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType) -> None:
|
||||
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
|
||||
super().__init__(description=agent.description)
|
||||
self._parent_topic_type = parent_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
|
||||
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)
|
||||
self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None
|
||||
|
||||
@event
|
||||
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
|
||||
@@ -48,38 +49,43 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
|
||||
to the delegate agent and publish the response."""
|
||||
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
||||
|
||||
# Handle tool calls.
|
||||
while isinstance(response, ToolCallMessage):
|
||||
self._logger.info(ContentPublishEvent(agent_message=response))
|
||||
if self._tool_agent_id is not None:
|
||||
# Handle tool calls.
|
||||
while isinstance(response, ToolCallMessage):
|
||||
# Log the tool call.
|
||||
event_logger.info(ToolCallEvent(agent_message=response, source=self.id))
|
||||
|
||||
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
|
||||
*[
|
||||
self.send_message(
|
||||
message=call,
|
||||
recipient=self._tool_agent_id,
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
for call in response.content
|
||||
]
|
||||
)
|
||||
# Combine the results in to a single response and handle exceptions.
|
||||
function_results: List[FunctionExecutionResult] = []
|
||||
for result in results:
|
||||
if isinstance(result, FunctionExecutionResult):
|
||||
function_results.append(result)
|
||||
elif isinstance(result, ToolException):
|
||||
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
|
||||
elif isinstance(result, BaseException):
|
||||
raise result # Unexpected exception.
|
||||
# Create a new tool call result message.
|
||||
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
|
||||
# TODO: use logging instead of print
|
||||
self._logger.info(ContentPublishEvent(agent_message=feedback, source=self._tool_agent_id.type))
|
||||
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
|
||||
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
|
||||
*[
|
||||
self.send_message(
|
||||
message=call,
|
||||
recipient=self._tool_agent_id,
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
for call in response.content
|
||||
]
|
||||
)
|
||||
# Combine the results in to a single response and handle exceptions.
|
||||
function_results: List[FunctionExecutionResult] = []
|
||||
for result in results:
|
||||
if isinstance(result, FunctionExecutionResult):
|
||||
function_results.append(result)
|
||||
elif isinstance(result, ToolException):
|
||||
function_results.append(
|
||||
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)
|
||||
)
|
||||
elif isinstance(result, BaseException):
|
||||
raise result # Unexpected exception.
|
||||
# Create a new tool call result message.
|
||||
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
|
||||
# Log the feedback.
|
||||
event_logger.info(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id))
|
||||
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
|
||||
|
||||
# Publish the response.
|
||||
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
|
||||
ContentPublishEvent(agent_message=response, source=self.id),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
)
|
||||
|
||||
@@ -9,6 +9,8 @@ from .._events import ContentPublishEvent, ContentRequestEvent
|
||||
from .._logging import EVENT_LOGGER_NAME
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class BaseGroupChatManager(SequentialRoutedAgent):
|
||||
"""Base class for a group chat manager that manages a group chat with multiple participants.
|
||||
@@ -50,7 +52,6 @@ class BaseGroupChatManager(SequentialRoutedAgent):
|
||||
self._participant_topic_types = participant_topic_types
|
||||
self._participant_descriptions = participant_descriptions
|
||||
self._message_thread: List[ChatMessage] = []
|
||||
self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME + ".agentchatchat")
|
||||
|
||||
@event
|
||||
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
|
||||
@@ -63,9 +64,7 @@ class BaseGroupChatManager(SequentialRoutedAgent):
|
||||
assert ctx.topic_id is not None
|
||||
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
|
||||
|
||||
# TODO: use something else other than print.
|
||||
|
||||
self._logger.info(ContentPublishEvent(agent_message=message.agent_message))
|
||||
event_logger.info(message)
|
||||
|
||||
# Process event from parent.
|
||||
if ctx.topic_id.type == self._parent_topic_type:
|
||||
|
||||
@@ -9,7 +9,7 @@ from autogen_core.components.tools import Tool
|
||||
|
||||
from autogen_agentchat.agents._base_chat_agent import ChatMessage
|
||||
|
||||
from ...agents import BaseChatAgent, TextMessage
|
||||
from ...agents import BaseChatAgent, BaseToolUseChatAgent, TextMessage
|
||||
from .._base_team import BaseTeam, TeamRunResult
|
||||
from .._events import ContentPublishEvent, ContentRequestEvent
|
||||
from ._base_chat_agent_container import BaseChatAgentContainer
|
||||
@@ -37,8 +37,8 @@ class RoundRobinGroupChat(BaseTeam):
|
||||
from autogen_agentchat.agents import ToolUseAssistantAgent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
|
||||
assistant = ToolUseAssistantAgent("Assistant", model_client=..., tool_schema=[...])
|
||||
team = RoundRobinGroupChat([assistant], tools=[...])
|
||||
assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...)
|
||||
team = RoundRobinGroupChat([assistant])
|
||||
await team.run("What's the weather in New York?")
|
||||
|
||||
A team with multiple participants:
|
||||
@@ -55,17 +55,21 @@ class RoundRobinGroupChat(BaseTeam):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, participants: List[BaseChatAgent], *, tools: List[Tool] | None = None):
|
||||
def __init__(self, participants: List[BaseChatAgent]):
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
if len(participants) != len(set(participant.name for participant in participants)):
|
||||
raise ValueError("The participant names must be unique.")
|
||||
for participant in participants:
|
||||
if isinstance(participant, BaseToolUseChatAgent) and not participant.registered_tools:
|
||||
raise ValueError(
|
||||
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
|
||||
)
|
||||
self._participants = participants
|
||||
self._team_id = str(uuid.uuid4())
|
||||
self._tools = tools or []
|
||||
|
||||
def _create_factory(
|
||||
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType
|
||||
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None
|
||||
) -> Callable[[], BaseChatAgentContainer]:
|
||||
def _factory() -> BaseChatAgentContainer:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
@@ -76,6 +80,16 @@ class RoundRobinGroupChat(BaseTeam):
|
||||
|
||||
return _factory
|
||||
|
||||
def _create_tool_agent_factory(
|
||||
self,
|
||||
caller_name: str,
|
||||
tools: List[Tool],
|
||||
) -> Callable[[], ToolAgent]:
|
||||
def _factory() -> ToolAgent:
|
||||
return ToolAgent(f"Tool agent for {caller_name}", tools)
|
||||
|
||||
return _factory
|
||||
|
||||
async def run(self, task: str) -> TeamRunResult:
|
||||
"""Run the team and return the result."""
|
||||
# Create the runtime.
|
||||
@@ -87,16 +101,23 @@ class RoundRobinGroupChat(BaseTeam):
|
||||
group_topic_type = "round_robin_group_topic"
|
||||
team_topic_type = "team_topic"
|
||||
|
||||
# Register the tool agent.
|
||||
tool_agent_type = await ToolAgent.register(
|
||||
runtime, "tool_agent", lambda: ToolAgent("Tool agent for round-robin group chat", self._tools)
|
||||
)
|
||||
# No subscriptions are needed for the tool agent, which will be called via direct messages.
|
||||
|
||||
# Register participants.
|
||||
participant_topic_types: List[str] = []
|
||||
participant_descriptions: List[str] = []
|
||||
for participant in self._participants:
|
||||
if isinstance(participant, BaseToolUseChatAgent):
|
||||
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
|
||||
# Register the tool agent.
|
||||
tool_agent_type = await ToolAgent.register(
|
||||
runtime,
|
||||
f"tool_agent_for_{participant.name}",
|
||||
self._create_tool_agent_factory(participant.name, participant.registered_tools),
|
||||
)
|
||||
# No subscriptions are needed for the tool agent, which will be called via direct messages.
|
||||
else:
|
||||
# No tool agent is needed.
|
||||
tool_agent_type = None
|
||||
|
||||
# Use the participant name as the agent type and topic type.
|
||||
agent_type = participant.name
|
||||
topic_type = participant.name
|
||||
|
||||
Reference in New Issue
Block a user