mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Add task type that are messages to enable multi-modal tasks. (#4091)
* Add task type that are messages to enable multi-modal tasks. * fix test
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import AsyncGenerator, List, Sequence
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import AgentMessage, ChatMessage, InnerMessage, TextMessage
|
||||
from ..messages import AgentMessage, ChatMessage, InnerMessage, MultiModalMessage, TextMessage
|
||||
|
||||
|
||||
class BaseChatAgent(ChatAgent, ABC):
|
||||
@@ -54,7 +54,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | None = None,
|
||||
task: str | TextMessage | MultiModalMessage | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
@@ -62,10 +62,13 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[ChatMessage] = []
|
||||
output_messages: List[AgentMessage] = []
|
||||
if task is not None:
|
||||
msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
if isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
elif isinstance(task, TextMessage | MultiModalMessage):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
response = await self.on_messages(input_messages, cancellation_token)
|
||||
if response.inner_messages is not None:
|
||||
output_messages += response.inner_messages
|
||||
@@ -75,7 +78,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | None = None,
|
||||
task: str | TextMessage | MultiModalMessage | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
|
||||
"""Run the agent with the given task and return a stream of messages
|
||||
@@ -84,11 +87,15 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
cancellation_token = CancellationToken()
|
||||
input_messages: List[ChatMessage] = []
|
||||
output_messages: List[AgentMessage] = []
|
||||
if task is not None:
|
||||
msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
if isinstance(task, str):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
yield text_msg
|
||||
elif isinstance(task, TextMessage | MultiModalMessage):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
yield task
|
||||
async for message in self.on_messages_stream(input_messages, cancellation_token):
|
||||
if isinstance(message, Response):
|
||||
yield message.chat_message
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import AsyncGenerator, Protocol, Sequence
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..messages import AgentMessage
|
||||
from ..messages import AgentMessage, MultiModalMessage, TextMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,7 +23,7 @@ class TaskRunner(Protocol):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | None = None,
|
||||
task: str | TextMessage | MultiModalMessage | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the task and return the result.
|
||||
@@ -36,7 +36,7 @@ class TaskRunner(Protocol):
|
||||
def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | None = None,
|
||||
task: str | TextMessage | MultiModalMessage | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
|
||||
"""Run the task and produces a stream of messages and the final result
|
||||
|
||||
@@ -18,7 +18,7 @@ from autogen_core.components import ClosureAgent, TypeSubscription
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import AgentMessage, TextMessage
|
||||
from ...messages import AgentMessage, MultiModalMessage, TextMessage
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
|
||||
@@ -160,7 +160,7 @@ class BaseGroupChat(Team, ABC):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | None = None,
|
||||
task: str | TextMessage | MultiModalMessage | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the team and return the result. The base implementation uses
|
||||
@@ -213,7 +213,7 @@ class BaseGroupChat(Team, ABC):
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | None = None,
|
||||
task: str | TextMessage | MultiModalMessage | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
|
||||
"""Run the team and produces a stream of messages and the final result
|
||||
@@ -266,10 +266,11 @@ class BaseGroupChat(Team, ABC):
|
||||
await self._init(self._runtime)
|
||||
|
||||
# Run the team by publishing the start message.
|
||||
if task is None:
|
||||
first_chat_message = None
|
||||
else:
|
||||
first_chat_message: TextMessage | MultiModalMessage | None = None
|
||||
if isinstance(task, str):
|
||||
first_chat_message = TextMessage(content=task, source="user")
|
||||
elif isinstance(task, TextMessage | MultiModalMessage):
|
||||
first_chat_message = task
|
||||
await self._runtime.publish_message(
|
||||
GroupChatStart(message=first_chat_message),
|
||||
topic_id=TopicId(type=self._group_topic_type, source=self._team_id),
|
||||
|
||||
@@ -8,7 +8,14 @@ from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import AssistantAgent, Handoff
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
from autogen_agentchat.messages import (
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessage,
|
||||
)
|
||||
from autogen_core.components import Image
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
@@ -202,3 +209,27 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
|
||||
# Generate a random base64 image.
|
||||
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
|
||||
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
|
||||
assert len(result.messages) == 2
|
||||
|
||||
@@ -18,6 +18,7 @@ from autogen_agentchat.messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
@@ -189,6 +190,26 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
# Test message input.
|
||||
# Text message.
|
||||
mock.reset()
|
||||
index = 0
|
||||
await team.reset()
|
||||
result_2 = await team.run(
|
||||
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
|
||||
)
|
||||
assert result == result_2
|
||||
|
||||
# Test multi-modal message.
|
||||
mock.reset()
|
||||
index = 0
|
||||
await team.reset()
|
||||
result_2 = await team.run(
|
||||
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
||||
)
|
||||
assert result.messages[0].content == result_2.messages[0].content[0]
|
||||
assert result.messages[1:] == result_2.messages[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
Reference in New Issue
Block a user