Add task types that are messages to enable multi-modal tasks. (#4091)

* Add task types that are messages to enable multi-modal tasks.

* fix test
This commit is contained in:
Eric Zhu
2024-11-07 21:38:41 -08:00
committed by GitHub
parent 9e388925d4
commit 5fa38b0166
5 changed files with 82 additions and 22 deletions

View File

@@ -4,7 +4,7 @@ from typing import AsyncGenerator, List, Sequence
from autogen_core.base import CancellationToken
from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, InnerMessage, TextMessage
from ..messages import AgentMessage, ChatMessage, InnerMessage, MultiModalMessage, TextMessage
class BaseChatAgent(ChatAgent, ABC):
@@ -54,7 +54,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def run(
self,
*,
task: str | None = None,
task: str | TextMessage | MultiModalMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
@@ -62,10 +62,13 @@ class BaseChatAgent(ChatAgent, ABC):
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentMessage] = []
if task is not None:
msg = TextMessage(content=task, source="user")
input_messages.append(msg)
output_messages.append(msg)
if isinstance(task, str):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, TextMessage | MultiModalMessage):
input_messages.append(task)
output_messages.append(task)
response = await self.on_messages(input_messages, cancellation_token)
if response.inner_messages is not None:
output_messages += response.inner_messages
@@ -75,7 +78,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def run_stream(
self,
*,
task: str | None = None,
task: str | TextMessage | MultiModalMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
@@ -84,11 +87,15 @@ class BaseChatAgent(ChatAgent, ABC):
cancellation_token = CancellationToken()
input_messages: List[ChatMessage] = []
output_messages: List[AgentMessage] = []
if task is not None:
msg = TextMessage(content=task, source="user")
input_messages.append(msg)
output_messages.append(msg)
yield msg
if isinstance(task, str):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, TextMessage | MultiModalMessage):
input_messages.append(task)
output_messages.append(task)
yield task
async for message in self.on_messages_stream(input_messages, cancellation_token):
if isinstance(message, Response):
yield message.chat_message

View File

@@ -3,7 +3,7 @@ from typing import AsyncGenerator, Protocol, Sequence
from autogen_core.base import CancellationToken
from ..messages import AgentMessage
from ..messages import AgentMessage, MultiModalMessage, TextMessage
@dataclass
@@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | None = None,
task: str | TextMessage | MultiModalMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
@@ -36,7 +36,7 @@ class TaskRunner(Protocol):
def run_stream(
self,
*,
task: str | None = None,
task: str | TextMessage | MultiModalMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result

View File

@@ -18,7 +18,7 @@ from autogen_core.components import ClosureAgent, TypeSubscription
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, TextMessage
from ...messages import AgentMessage, MultiModalMessage, TextMessage
from ._base_group_chat_manager import BaseGroupChatManager
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@@ -160,7 +160,7 @@ class BaseGroupChat(Team, ABC):
async def run(
self,
*,
task: str | None = None,
task: str | TextMessage | MultiModalMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
@@ -213,7 +213,7 @@ class BaseGroupChat(Team, ABC):
async def run_stream(
self,
*,
task: str | None = None,
task: str | TextMessage | MultiModalMessage | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
@@ -266,10 +266,11 @@ class BaseGroupChat(Team, ABC):
await self._init(self._runtime)
# Run the team by publishing the start message.
if task is None:
first_chat_message = None
else:
first_chat_message: TextMessage | MultiModalMessage | None = None
if isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, TextMessage | MultiModalMessage):
first_chat_message = task
await self._runtime.publish_message(
GroupChatStart(message=first_chat_message),
topic_id=TopicId(type=self._group_topic_type, source=self._team_id),

View File

@@ -8,7 +8,14 @@ from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.base import TaskResult
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from autogen_agentchat.messages import (
HandoffMessage,
MultiModalMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_core.components import Image
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
@@ -202,3 +209,27 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
# Generate a random base64 image.
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
assert len(result.messages) == 2

View File

@@ -18,6 +18,7 @@ from autogen_agentchat.messages import (
AgentMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
@@ -189,6 +190,26 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert message == result.messages[index]
index += 1
# Test message input.
# Text message.
mock.reset()
index = 0
await team.reset()
result_2 = await team.run(
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
)
assert result == result_2
# Test multi-modal message.
mock.reset()
index = 0
await team.reset()
result_2 = await team.run(
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
)
assert result.messages[0].content == result_2.messages[0].content[0]
assert result.messages[1:] == result_2.messages[1:]
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: