Move tools to agent in agentchat; refactored logging to support tool events (#3665)

* Move tool to agent; refactor logging in agentchat

* Update notebook
This commit is contained in:
Eric Zhu
2024-10-07 09:38:24 -07:00
committed by GitHub
parent be5c0b5d3e
commit 54eaa2bb4e
11 changed files with 236 additions and 124 deletions

View File

@@ -1,5 +1,6 @@
from ._base_chat_agent import (
BaseChatAgent,
BaseToolUseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
@@ -13,6 +14,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"ChatMessage",
"TextMessage",
"MultiModalMessage",

View File

@@ -4,6 +4,7 @@ from typing import List, Sequence
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tools import Tool
from pydantic import BaseModel
@@ -49,7 +50,7 @@ class StopMessage(BaseMessage):
"""The content for the stop message."""
ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
"""A message used by agents in a team."""
@@ -79,3 +80,21 @@ class BaseChatAgent(ABC):
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
...
class BaseToolUseChatAgent(BaseChatAgent):
    """Base class for chat agents that are able to use tools.

    To build a tool-using agent, subclass this class and have
    :meth:`on_messages` return a ToolCallMessage whenever tools should be
    invoked; the corresponding ToolCallResultMessage is delivered back through
    the input to :meth:`on_messages`.
    """

    def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
        """Initialize the agent and remember the tools it may call.

        Args:
            name: Forwarded unchanged to the base class initializer.
            description: Forwarded unchanged to the base class initializer.
            registered_tools: Tools made available to this agent. The list is
                stored as-is (no copy is taken).
        """
        super().__init__(name, description)
        self._registered_tools = registered_tools

    @property
    def registered_tools(self) -> List[Tool]:
        """The tools registered with this agent at construction time."""
        return self._registered_tools

View File

@@ -10,10 +10,10 @@ from autogen_core.components.models import (
SystemMessage,
UserMessage,
)
from autogen_core.components.tools import ToolSchema
from autogen_core.components.tools import Tool
from ._base_chat_agent import (
BaseChatAgent,
BaseToolUseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
@@ -23,7 +23,7 @@ from ._base_chat_agent import (
)
class ToolUseAssistantAgent(BaseChatAgent):
class ToolUseAssistantAgent(BaseToolUseChatAgent):
"""An agent that provides assistance with tool use.
It responds with a StopMessage when 'terminate' is detected in the response.
@@ -33,15 +33,15 @@ class ToolUseAssistantAgent(BaseChatAgent):
self,
name: str,
model_client: ChatCompletionClient,
tool_schema: List[ToolSchema],
registered_tools: List[Tool],
*,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply 'TERMINATE' in the end when the task is completed.",
):
super().__init__(name=name, description=description)
super().__init__(name=name, description=description, registered_tools=registered_tools)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tool_schema = tool_schema
self._tool_schema = [tool.schema for tool in registered_tools]
self._model_context: List[LLMMessage] = []
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:

View File

@@ -1,8 +1,7 @@
from typing import Optional
from autogen_core.base import AgentId
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel
from ..agents import ChatMessage
from ..agents import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
class ContentPublishEvent(BaseModel):
@@ -11,9 +10,13 @@ class ContentPublishEvent(BaseModel):
content of the event.
"""
agent_message: ChatMessage
agent_message: TextMessage | MultiModalMessage | StopMessage
"""The message published by the agent."""
source: Optional[str] = None
source: AgentId | None = None
"""The agent ID that published the message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ContentRequestEvent(BaseModel):
@@ -22,3 +25,27 @@ class ContentRequestEvent(BaseModel):
"""
...
class ToolCallEvent(BaseModel):
    """Event emitted when an agent requests one or more tool calls."""

    # Allow field types that pydantic cannot validate on its own
    # (presumably required for AgentId -- confirm).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    agent_message: ToolCallMessage
    """The message carrying the requested tool calls."""

    source: AgentId
    """The ID of the agent that sent the tool call message."""
class ToolCallResultEvent(BaseModel):
    """Event emitted once a tool call has completed."""

    # Allow field types that pydantic cannot validate on its own
    # (presumably required for AgentId -- confirm).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    agent_message: ToolCallResultMessage
    """The message carrying the tool call results."""

    source: AgentId
    """The ID of the agent that sent the tool call result message."""

View File

@@ -3,13 +3,14 @@ import logging
import sys
from dataclasses import asdict, is_dataclass
from datetime import datetime
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Union
from autogen_core.base import AgentId
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from ..agents import ChatMessage, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from ._events import ContentPublishEvent
from ._events import ContentPublishEvent, ToolCallEvent, ToolCallResultEvent
EVENT_LOGGER_NAME = "autogen_agentchat.events"
ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[FunctionExecutionResult]]
@@ -17,7 +18,8 @@ ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[Funct
class BaseLogHandler(logging.Handler):
def serialize_content(
self, content: Union[ContentType, Sequence[ChatMessage], ChatMessage]
self,
content: Union[ContentType, ChatMessage],
) -> Union[List[Any], Dict[str, Any], str]:
if isinstance(content, (str, list)):
return content
@@ -41,19 +43,35 @@ class BaseLogHandler(logging.Handler):
class ConsoleLogHandler(BaseLogHandler):
def _format_message(
    self,
    *,
    source_agent_id: AgentId | None,
    message: ChatMessage,
    timestamp: str,
) -> str:
    """Render one chat message as a block of console text.

    The block is a 75-dash rule, a header line wrapped in the ANSI escape
    sequence ``\\033[91m`` ... ``\\033[0m``, the serialized message content and
    a trailing ``From: <message.source>`` line.

    Args:
        source_agent_id: ID of the agent that produced the event, or None when
            unknown. When given, its ``type`` is appended to the header.
        message: The message whose ``content`` and ``source`` are printed.
        timestamp: Pre-formatted timestamp shown in the header.

    Returns:
        The text to write to the console.
    """
    body = f"{self.serialize_content(message.content)}\nFrom: {message.source}"
    rule = "-" * 75
    if source_agent_id is None:
        header = f"[{timestamp}]"
    else:
        # Show the agent type instead of the full agent ID: it reads better,
        # and in AgentChat the agent type is unique for each agent.
        header = f"[{timestamp}], {source_agent_id.type}"
    return f"\n{rule} \n\033[91m{header}:\033[0m\n\n{body}"
def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
console_message = (
f"\n{'-'*75} \n"
f"\033[91m[{ts}], {record.msg.agent_message.source}:\033[0m\n"
f"\n{self.serialize_content(record.msg.agent_message.content)}"
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
sys.stdout.write(
self._format_message(
source_agent_id=record.msg.source,
message=record.msg.agent_message,
timestamp=ts,
)
sys.stdout.write(console_message)
sys.stdout.flush()
except Exception:
self.handleError(record)
)
sys.stdout.flush()
else:
raise ValueError(f"Unexpected log record: {record.msg}")
class FileLogHandler(BaseLogHandler):
@@ -62,32 +80,37 @@ class FileLogHandler(BaseLogHandler):
self.filename = filename
self.file_handler = logging.FileHandler(filename)
def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"source": record.msg.agent_message.source,
"message": self.serialize_content(record.msg.agent_message.content),
"type": "OrchestrationEvent",
},
default=self.json_serializer,
)
def _format_entry(self, *, source: AgentId | None, message: ChatMessage, timestamp: str) -> Dict[str, Any]:
    """Build the dictionary that is written to the log file for one event.

    Args:
        source: ID of the agent that produced the event, stored unchanged
            under ``"source"``.
        message: The chat message; the whole message (not only its content)
            is passed through ``serialize_content``.
        timestamp: Pre-formatted timestamp stored under ``"timestamp"``.

    Returns:
        A dict with the keys ``timestamp``, ``source``, ``message`` and
        ``type``; ``type`` is always ``"OrchestrationEvent"``.
    """
    entry: Dict[str, Any] = dict(timestamp=timestamp, source=source)
    entry["message"] = self.serialize_content(message)
    entry["type"] = "OrchestrationEvent"
    return entry
file_record = logging.LogRecord(
name=record.name,
level=record.levelno,
pathname=record.pathname,
lineno=record.lineno,
msg=log_entry,
args=(),
exc_info=record.exc_info,
)
self.file_handler.emit(file_record)
except Exception:
self.handleError(record)
def emit(self, record: logging.LogRecord) -> None:
    """Serialize an event record to JSON and write it through the file handler.

    Args:
        record: A log record whose ``msg`` is a ContentPublishEvent,
            ToolCallEvent or ToolCallResultEvent.

    Raises:
        ValueError: If ``record.msg`` is not one of the supported event types.
    """
    timestamp = datetime.fromtimestamp(record.created).isoformat()
    event = record.msg
    # Guard clause: anything other than a known event is rejected outright.
    if not isinstance(event, (ContentPublishEvent, ToolCallEvent, ToolCallResultEvent)):
        raise ValueError(f"Unexpected log record: {record.msg}")

    entry = self._format_entry(
        source=event.source,
        message=event.agent_message,
        timestamp=timestamp,
    )
    # json_serializer is the fallback for values json cannot encode natively.
    serialized = json.dumps(entry, default=self.json_serializer)

    # Re-wrap the JSON text in a fresh LogRecord that keeps the original
    # record's origin metadata, then delegate the write to the file handler.
    self.file_handler.emit(
        logging.LogRecord(
            name=record.name,
            level=record.levelno,
            pathname=record.pathname,
            lineno=record.lineno,
            msg=serialized,
            args=(),
            exc_info=record.exc_info,
        )
    )
def close(self) -> None:
self.file_handler.close()

View File

@@ -8,10 +8,12 @@ from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tool_agent import ToolException
from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from .._events import ContentPublishEvent, ContentRequestEvent
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
from .._logging import EVENT_LOGGER_NAME
from ._sequential_routed_agent import SequentialRoutedAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class BaseChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
@@ -21,16 +23,15 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
agent (BaseChatAgent): The agent to delegate message handling to.
tool_agent_type (AgentType): The agent type of the tool agent to use for tool calls.
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
"""
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType) -> None:
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)
self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME)
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None
@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
@@ -48,38 +49,43 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
to the delegate agent and publish the response."""
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
# Handle tool calls.
while isinstance(response, ToolCallMessage):
self._logger.info(ContentPublishEvent(agent_message=response))
if self._tool_agent_id is not None:
# Handle tool calls.
while isinstance(response, ToolCallMessage):
# Log the tool call.
event_logger.info(ToolCallEvent(agent_message=response, source=self.id))
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results into a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# TODO: use logging instead of print
self._logger.info(ContentPublishEvent(agent_message=feedback, source=self._tool_agent_id.type))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results into a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)
)
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# Log the feedback.
event_logger.info(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
# Publish the response.
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
self._message_buffer.clear()
await self.publish_message(
ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
ContentPublishEvent(agent_message=response, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)

View File

@@ -9,6 +9,8 @@ from .._events import ContentPublishEvent, ContentRequestEvent
from .._logging import EVENT_LOGGER_NAME
from ._sequential_routed_agent import SequentialRoutedAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class BaseGroupChatManager(SequentialRoutedAgent):
"""Base class for a group chat manager that manages a group chat with multiple participants.
@@ -50,7 +52,6 @@ class BaseGroupChatManager(SequentialRoutedAgent):
self._participant_topic_types = participant_topic_types
self._participant_descriptions = participant_descriptions
self._message_thread: List[ChatMessage] = []
self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME + ".agentchatchat")
@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
@@ -63,9 +64,7 @@ class BaseGroupChatManager(SequentialRoutedAgent):
assert ctx.topic_id is not None
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
# TODO: use something else other than print.
self._logger.info(ContentPublishEvent(agent_message=message.agent_message))
event_logger.info(message)
# Process event from parent.
if ctx.topic_id.type == self._parent_topic_type:

View File

@@ -9,7 +9,7 @@ from autogen_core.components.tools import Tool
from autogen_agentchat.agents._base_chat_agent import ChatMessage
from ...agents import BaseChatAgent, TextMessage
from ...agents import BaseChatAgent, BaseToolUseChatAgent, TextMessage
from .._base_team import BaseTeam, TeamRunResult
from .._events import ContentPublishEvent, ContentRequestEvent
from ._base_chat_agent_container import BaseChatAgentContainer
@@ -37,8 +37,8 @@ class RoundRobinGroupChat(BaseTeam):
from autogen_agentchat.agents import ToolUseAssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
assistant = ToolUseAssistantAgent("Assistant", model_client=..., tool_schema=[...])
team = RoundRobinGroupChat([assistant], tools=[...])
assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...)
team = RoundRobinGroupChat([assistant])
await team.run("What's the weather in New York?")
A team with multiple participants:
@@ -55,17 +55,21 @@ class RoundRobinGroupChat(BaseTeam):
"""
def __init__(self, participants: List[BaseChatAgent], *, tools: List[Tool] | None = None):
def __init__(self, participants: List[BaseChatAgent]):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
if len(participants) != len(set(participant.name for participant in participants)):
raise ValueError("The participant names must be unique.")
for participant in participants:
if isinstance(participant, BaseToolUseChatAgent) and not participant.registered_tools:
raise ValueError(
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
)
self._participants = participants
self._team_id = str(uuid.uuid4())
self._tools = tools or []
def _create_factory(
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None
) -> Callable[[], BaseChatAgentContainer]:
def _factory() -> BaseChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
@@ -76,6 +80,16 @@ class RoundRobinGroupChat(BaseTeam):
return _factory
def _create_tool_agent_factory(self, caller_name: str, tools: List[Tool]) -> Callable[[], ToolAgent]:
    """Return a zero-argument factory that builds a ToolAgent for one caller.

    The arguments are bound when this method is called, so every factory
    keeps its own caller name and tool list.

    Args:
        caller_name: Name of the agent the tool agent serves; used in the
            tool agent's description.
        tools: The tools handed to the created ToolAgent.

    Returns:
        A callable that creates the ToolAgent each time it is invoked.
    """
    description = f"Tool agent for {caller_name}"
    return lambda: ToolAgent(description, tools)
async def run(self, task: str) -> TeamRunResult:
"""Run the team and return the result."""
# Create the runtime.
@@ -87,16 +101,23 @@ class RoundRobinGroupChat(BaseTeam):
group_topic_type = "round_robin_group_topic"
team_topic_type = "team_topic"
# Register the tool agent.
tool_agent_type = await ToolAgent.register(
runtime, "tool_agent", lambda: ToolAgent("Tool agent for round-robin group chat", self._tools)
)
# No subscriptions are needed for the tool agent, which will be called via direct messages.
# Register participants.
participant_topic_types: List[str] = []
participant_descriptions: List[str] = []
for participant in self._participants:
if isinstance(participant, BaseToolUseChatAgent):
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
# Register the tool agent.
tool_agent_type = await ToolAgent.register(
runtime,
f"tool_agent_for_{participant.name}",
self._create_tool_agent_factory(participant.name, participant.registered_tools),
)
# No subscriptions are needed for the tool agent, which will be called via direct messages.
else:
# No tool agent is needed.
tool_agent_type = None
# Use the participant name as the agent type and topic type.
agent_type = participant.name
topic_type = participant.name