Files
autogen/test/oai/test_gemini.py
Beibin Li 0aaf30a8da Merge "Gemini" feature into the main branch (#2360)
* Start Gemini integration: works ok with Text now

* Gemini notebook lint

* try catch "import" for Gemini

* Debug: id issue for chat completion in Gemini

* Add RAG example

* Update docs for RAG

* Fix missing pydash

* Remove temp folder

* Fix test error in runs/7206014032/job/19630042864

* Fix tqdm warning

* Fix notebook output

* Gemini's vision model is supported now

* Install instructions for the Gemini branch

* Catch and retry when see Interval Server Error 500

* Allow gemini to take more flexible messages
i.e., it can take messages where "user" is not the last role.

* Use int time for Gemini client

* Handle other exceptions in gemini call

* rename to "create" function for gemini

* GeminiClient compatible with ModelClient now

* Lint

* Update instructions in Gemini notebook

* Lint

* Remove empty blocks from Gemini notebook

* Add gemini into example page

* self.create instead of call

* Add py and Py into python execution

* Remove error code from merging

* Remove pydash dependency for gemini

* Add cloud-gemini doc

* Remove temp file

* cache import update

* Add test case for summary with mm input

* Lint: warnings instead of print

* Add test cases for gemini

* Gemini test config

* Disable default model for gemini

* Typo fix in gemini workflow

* Correct grammar in example notebook

* Raise if "model" is not provided in create(...)

* Move TODOs into a roadmap

* Update .github/workflows/contrib-tests.yml

Co-authored-by: Davor Runje <davor@airt.ai>

* Gemini test config update

* Update setup.py

Co-authored-by: Davor Runje <davor@airt.ai>

* Update test/oai/test_gemini.py

Co-authored-by: Davor Runje <davor@airt.ai>

* Update test/oai/test_gemini.py

Co-authored-by: Davor Runje <davor@airt.ai>

* Remove python 3.8 from gemini
No google's generativeai for Windows with Python 3.8

* Update import error handling for gemini

* Count tokens and cost for gemini

---------

Co-authored-by: Li Jiang <bnujli@gmail.com>
Co-authored-by: Davor Runje <davor@airt.ai>
2024-04-17 00:24:07 +00:00

149 lines
5.6 KiB
Python

from unittest.mock import MagicMock, patch
import pytest
try:
from google.api_core.exceptions import InternalServerError
from autogen.oai.gemini import GeminiClient
skip = False
except ImportError:
GeminiClient = object
InternalServerError = object
skip = True
# Fixtures for mock data
@pytest.fixture
def mock_response():
class MockResponse:
def __init__(self, text, choices, usage, cost, model):
self.text = text
self.choices = choices
self.usage = usage
self.cost = cost
self.model = model
return MockResponse
@pytest.fixture
def gemini_client():
return GeminiClient(api_key="fake_api_key")
# Test initialization and configuration
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_initialization():
with pytest.raises(AssertionError):
GeminiClient() # Should raise an AssertionError due to missing API key
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_internal_server_error_retry(mock_genai, gemini_client):
mock_genai.GenerativeModel.side_effect = [InternalServerError("Test Error"), None] # First call fails
# Mock successful response
mock_chat = MagicMock()
mock_chat.send_message.return_value = "Successful response"
mock_genai.GenerativeModel.return_value.start_chat.return_value = mock_chat
with patch.object(gemini_client, "create", return_value="Retried Successfully"):
response = gemini_client.create({"model": "gemini-pro", "messages": [{"content": "Hello"}]})
assert response == "Retried Successfully", "Should retry on InternalServerError"
# Test cost calculation
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_cost_calculation(gemini_client, mock_response):
response = mock_response(
text="Example response",
choices=[{"message": "Test message 1"}],
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
cost=0.01,
model="gemini-pro",
)
assert gemini_client.cost(response) > 0, "Cost should be correctly calculated as zero"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
def test_create_response(mock_configure, mock_generative_model, gemini_client):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_configure.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
# Call the create method
response = gemini_client.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
def test_create_vision_model_response(mock_configure, mock_generative_model, gemini_client):
# Mock the genai model configuration and creation process
mock_model = MagicMock()
mock_configure.return_value = None
mock_generative_model.return_value = mock_model
# Set up a mock to simulate the vision model behavior
mock_vision_response = MagicMock()
mock_vision_part = MagicMock(text="Vision model output")
# Setting up the chain of return values for vision model response
mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = (
mock_vision_part
)
mock_model.generate_content.return_value = mock_vision_response
# Call the create method with vision model parameters
response = gemini_client.create(
{
"model": "gemini-pro-vision", # Vision model name
"messages": [
{
"content": [
{"type": "text", "text": "Let's play a game."},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
},
},
],
"role": "user",
}
], # Assuming a simple content input for vision
"stream": False,
}
)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.content == "Vision model output"
), "Response content should match expected output from vision model"