mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-08 00:35:28 -05:00
* LMM code added * LLaVA notebook update * Test cases and notebook modified for OpenAI v1 * Move LMM into contrib to resolve test and deployment issues. In the future, we can install pillow by default and then move the LMM agents back into agentchat. * LMM test setup update * try...except... clause for LMM tests * Disable patch for LLaVA agent test to resolve a dependency issue for the build * Add LMM blog * Change docstring for LMM agents * Docstring update patch * llava: insert reply at position 1 now, so it can still handle human_input_mode and max_consecutive_reply * Resolve comments. Fixing: typos, blogs, yml, and add OpenAIWrapper * Signature typo fix for LMM agent: system_message * Update LMM "content" from latest OpenAI release. Reference: https://platform.openai.com/docs/guides/vision * Update LMM test according to latest OpenAI release * Fully support GPT-4V now: 1. Add a notebook for GPT-4V; the LLaVA notebook is also updated. 2. img_utils updated. 3. GPT-4V formatter now returns a base64 image with its MIME type. 4. Infer the MIME type directly from the b64 image content (when loading without a file suffix). 5. Test cases modified according to all the related changes. * GPT-4V link updated in blog --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
194 lines
7.2 KiB
Python
194 lines
7.2 KiB
Python
import base64
import os
import pdb
import tempfile
import unittest
from unittest.mock import patch

import pytest
import requests

try:
    from PIL import Image

    from autogen.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formater
except ImportError:
    skip = True
else:
    skip = False
|
|
|
|
|
|
# A tiny 5x5 PNG image, base64-encoded with no prefix (the raw payload).
raw_encoded_image = (
    "iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
    "//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)

# The same image as a data URI. Tests rely on it containing exactly one comma
# separating the "data:image/png;base64" header from ``raw_encoded_image``.
# Deriving it from the raw payload keeps the two constants in sync.
base64_encoded_image = "data:image/png;base64," + raw_encoded_image
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGetImageData(unittest.TestCase):
    """Tests for ``get_image_data`` with URL, data-URI and local-file inputs."""

    def test_http_image(self):
        """A URL is fetched through ``requests.get``; its body comes back base64-encoded."""
        with patch("requests.get") as mock_get:
            mock_response = requests.Response()
            mock_response.status_code = 200
            mock_response._content = b"fake image content"
            mock_get.return_value = mock_response

            result = get_image_data("http://example.com/image.png")
            self.assertEqual(result, base64.b64encode(b"fake image content").decode("utf-8"))

    def test_base64_encoded_image(self):
        """A data URI yields only the payload that follows its first comma."""
        result = get_image_data(base64_encoded_image)
        self.assertEqual(result, base64_encoded_image.split(",", 1)[1])

    def test_local_image(self):
        """A local file path yields the file's bytes, base64-encoded."""
        # Write into a private temporary directory instead of the CWD so the
        # file is always removed, even when the assertion below fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file = os.path.join(temp_dir, "_temp.png")

            image = Image.new("RGB", (60, 30), color=(73, 109, 137))
            image.save(temp_file)

            result = get_image_data(temp_file)
            with open(temp_file, "rb") as temp_image_file:
                expected_content = base64.b64encode(temp_image_file.read()).decode("utf-8")

            self.assertEqual(result, expected_content)
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLlavaFormater(unittest.TestCase):
    """Checks how ``llava_formater`` rewrites ``<img ...>`` tags in a prompt."""

    # Shared single-image prompt used by the tag-rewriting tests.
    _IMAGE_PROMPT = "This is a test with an image <img http://example.com/image.png>."

    def test_no_images(self):
        """
        A prompt without image tags is returned unchanged, paired with an empty image list.
        """
        text = "This is a test."
        self.assertEqual(llava_formater(text), (text, []))

    @patch("autogen.img_utils.get_image_data")
    def test_with_images(self, fake_loader):
        """
        An image tag becomes a plain ``<image>`` token and its loaded data is collected.
        """
        # Stub the loader so the test does not depend on fetching the URL.
        fake_loader.return_value = raw_encoded_image

        formatted = llava_formater(self._IMAGE_PROMPT)

        self.assertEqual(formatted, ("This is a test with an image <image>.", [raw_encoded_image]))

    @patch("autogen.img_utils.get_image_data")
    def test_with_ordered_images(self, fake_loader):
        """
        With ``order_image_tokens=True`` the token carries its index, e.g. ``<image 0>``.
        """
        # Stub the loader so the test does not depend on fetching the URL.
        fake_loader.return_value = raw_encoded_image

        formatted = llava_formater(self._IMAGE_PROMPT, order_image_tokens=True)

        self.assertEqual(formatted, ("This is a test with an image <image 0>.", [raw_encoded_image]))
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGpt4vFormatter(unittest.TestCase):
    """Checks the content-part list that ``gpt4v_formatter`` builds from a prompt."""

    @staticmethod
    def _text(value):
        """Build the expected text content part."""
        return {"type": "text", "text": value}

    @staticmethod
    def _image():
        """Build the expected image content part for the stubbed image payload."""
        return {"type": "image_url", "image_url": {"url": base64_encoded_image}}

    def test_no_images(self):
        """
        A prompt without image tags becomes a single text part.
        """
        plain = "This is a test."
        self.assertEqual(gpt4v_formatter(plain), [self._text(plain)])

    @patch("autogen.img_utils.get_image_data")
    def test_with_images(self, fake_loader):
        """
        One image tag splits the prompt into text / image / text parts.
        """
        # Stub the loader so the test does not depend on fetching the URL.
        fake_loader.return_value = raw_encoded_image

        formatted = gpt4v_formatter("This is a test with an image <img http://example.com/image.png>.")

        self.assertEqual(
            formatted,
            [self._text("This is a test with an image "), self._image(), self._text(".")],
        )

    @patch("autogen.img_utils.get_image_data")
    def test_multiple_images(self, fake_loader):
        """
        Several image tags alternate with the surrounding text parts, in order.
        """
        # Stub the loader so the test does not depend on fetching the URLs.
        fake_loader.return_value = raw_encoded_image

        prompt = (
            "This is a test with images <img http://example.com/image1.png> and <img http://example.com/image2.png>."
        )
        wanted = [
            self._text("This is a test with images "),
            self._image(),
            self._text(" and "),
            self._image(),
            self._text("."),
        ]

        self.assertEqual(gpt4v_formatter(prompt), wanted)
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestExtractImgPaths(unittest.TestCase):
    """Checks which image references ``extract_img_paths`` finds in free text."""

    def _assert_paths(self, paragraph, expected):
        """Assert that ``extract_img_paths(paragraph)`` equals ``expected``."""
        self.assertEqual(extract_img_paths(paragraph), expected)

    def test_no_images(self):
        """
        Text without any image reference yields an empty list.
        """
        self._assert_paths("This is a test paragraph with no images.", [])

    def test_with_images(self):
        """
        Image URLs are returned in the order they appear in the text.
        """
        self._assert_paths(
            "This is a test paragraph with images http://example.com/image1.jpg and http://example.com/image2.png.",
            ["http://example.com/image1.jpg", "http://example.com/image2.png"],
        )

    def test_mixed_case(self):
        """
        Upper- and mixed-case extensions are matched and returned verbatim.
        """
        self._assert_paths(
            "Mixed case extensions http://example.com/image.JPG and http://example.com/image.Png.",
            ["http://example.com/image.JPG", "http://example.com/image.Png"],
        )

    def test_local_paths(self):
        """
        Bare local file names with image extensions are extracted too.
        """
        self._assert_paths("Local paths image1.jpeg and image2.GIF.", ["image1.jpeg", "image2.GIF"])
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly with the stdlib unittest runner.
    unittest.main()
|