mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Add task type that are messages to enable multi-modal tasks. (#4091)
* Add task type that are messages to enable multi-modal tasks. * fix test
This commit is contained in:
@@ -8,7 +8,14 @@ from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import AssistantAgent, Handoff
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
from autogen_agentchat.messages import (
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessage,
|
||||
)
|
||||
from autogen_core.components import Image
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
@@ -202,3 +209,27 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
|
||||
# Generate a random base64 image.
|
||||
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
|
||||
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
|
||||
assert len(result.messages) == 2
|
||||
|
||||
@@ -18,6 +18,7 @@ from autogen_agentchat.messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
@@ -189,6 +190,26 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
# Test message input.
|
||||
# Text message.
|
||||
mock.reset()
|
||||
index = 0
|
||||
await team.reset()
|
||||
result_2 = await team.run(
|
||||
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
|
||||
)
|
||||
assert result == result_2
|
||||
|
||||
# Test multi-modal message.
|
||||
mock.reset()
|
||||
index = 0
|
||||
await team.reset()
|
||||
result_2 = await team.run(
|
||||
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
||||
)
|
||||
assert result.messages[0].content == result_2.messages[0].content[0]
|
||||
assert result.messages[1:] == result_2.messages[1:]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
Reference in New Issue
Block a user