Add task types that are messages to enable multi-modal tasks. (#4091)

* Add task types that are messages to enable multi-modal tasks.

* fix test
This commit is contained in:
Eric Zhu
2024-11-07 21:38:41 -08:00
committed by GitHub
parent 9e388925d4
commit 5fa38b0166
5 changed files with 82 additions and 22 deletions

View File

@@ -8,7 +8,14 @@ from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent, Handoff
from autogen_agentchat.base import TaskResult
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from autogen_agentchat.messages import (
HandoffMessage,
MultiModalMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_core.components import Image
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
@@ -202,3 +209,27 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
    """Run an AssistantAgent with a MultiModalMessage (text plus image) as the task.

    ``AsyncCompletions.create`` is replaced with a mock so no real request is
    made; the test only checks how many messages the run result contains.
    """
    model_name = "gpt-4o-2024-05-13"

    # Canned completion handed to the mock: one assistant message saying "Hello".
    reply = ChatCompletionMessage(content="Hello", role="assistant")
    token_usage = CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0)
    canned_completion = ChatCompletion(
        id="id2",
        choices=[Choice(finish_reason="stop", index=0, message=reply)],
        created=0,
        model=model_name,
        object="chat.completion",
        usage=token_usage,
    )
    completion_mock = _MockChatCompletion([canned_completion])
    monkeypatch.setattr(AsyncCompletions, "create", completion_mock.mock_create)

    client = OpenAIChatCompletionClient(model=model_name, api_key="")
    assistant = AssistantAgent(name="assistant", model_client=client)

    # Hard-coded base64-encoded image payload used as the non-text part of the task.
    encoded_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    task_message = MultiModalMessage(source="user", content=["Test", Image.from_base64(encoded_image)])

    outcome = await assistant.run(task=task_message)

    # Presumably the task message itself plus the assistant's single reply.
    assert len(outcome.messages) == 2

View File

@@ -18,6 +18,7 @@ from autogen_agentchat.messages import (
AgentMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
@@ -189,6 +190,26 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert message == result.messages[index]
index += 1
# Test message input.
# Text message.
mock.reset()
index = 0
await team.reset()
result_2 = await team.run(
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
)
assert result == result_2
# Test multi-modal message.
mock.reset()
index = 0
await team.reset()
result_2 = await team.run(
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
)
assert result.messages[0].content == result_2.messages[0].content[0]
assert result.messages[1:] == result_2.messages[1:]
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: