mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Use Gemini without API key (#2805)
* google default auth and svc keyfile for Gemini * [.Net] Release note for 0.0.14 (#2815) * update release note * update trigger * [.Net] Update website for AutoGen.SemanticKernel and AutoGen.Ollama (#2814) support vertex ai compute region * [CAP] User supplied threads for agents (#2812) * First pass: message loop in main thread * pypi version bump * Fix readme * Better example * Fixed docs * pre-commit fixes * refactoring, minor fixes, update gemini demo ipynb * add new deps again and reset line endings * Docstring for the init function. Use private methods * improve docstring --------- Co-authored-by: Xiaoyun Zhang <bigmiao.zhang@gmail.com> Co-authored-by: Rajan <rajan.chari@yahoo.com> Co-authored-by: Zoltan Lux <z.lux@campus.tu-berlin.de>
This commit is contained in:
@@ -42,12 +42,16 @@ from typing import Any, Dict, List, Mapping, Union
|
||||
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
import vertexai
|
||||
from google.ai.generativelanguage import Content, Part
|
||||
from google.api_core.exceptions import InternalServerError
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from PIL import Image
|
||||
from vertexai.generative_models import Content as VertexAIContent
|
||||
from vertexai.generative_models import GenerativeModel
|
||||
from vertexai.generative_models import Part as VertexAIPart
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
@@ -68,14 +72,49 @@ class GeminiClient:
|
||||
"max_output_tokens": "max_output_tokens",
|
||||
}
|
||||
|
||||
def _initialize_vartexai(self, **params):
|
||||
if "google_application_credentials" in params:
|
||||
# Path to JSON Keyfile
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
|
||||
vertexai_init_args = {}
|
||||
if "project_id" in params:
|
||||
vertexai_init_args["project"] = params["project_id"]
|
||||
if "location" in params:
|
||||
vertexai_init_args["location"] = params["location"]
|
||||
if vertexai_init_args:
|
||||
vertexai.init(**vertexai_init_args)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Uses either either api_key for authentication from the LLM config
|
||||
(specifying the GOOGLE_API_KEY environment variable also works),
|
||||
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
|
||||
where project_id and location can also be passed as parameters. Service account key file can also be used.
|
||||
If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
|
||||
which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for using Gemini.
|
||||
google_application_credentials (str): Path to the JSON service account key file of the service account.
|
||||
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
||||
can also be set instead of using this argument.
|
||||
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
|
||||
location (str): Compute region to be used, like 'us-west1'.
|
||||
This parameter is only valid in case no API key is specified.
|
||||
"""
|
||||
self.api_key = kwargs.get("api_key", None)
|
||||
if not self.api_key:
|
||||
self.api_key = os.getenv("GOOGLE_API_KEY")
|
||||
|
||||
assert (
|
||||
self.api_key
|
||||
), "Please provide api_key in your config list entry for Gemini or set the GOOGLE_API_KEY env variable."
|
||||
if self.api_key is None:
|
||||
self.use_vertexai = True
|
||||
self._initialize_vartexai(**kwargs)
|
||||
else:
|
||||
self.use_vertexai = False
|
||||
else:
|
||||
self.use_vertexai = False
|
||||
if not self.use_vertexai:
|
||||
assert ("project_id" not in kwargs) and (
|
||||
"location" not in kwargs
|
||||
), "Google Cloud project and compute location cannot be set when using an API Key!"
|
||||
|
||||
def message_retrieval(self, response) -> List:
|
||||
"""
|
||||
@@ -102,6 +141,12 @@ class GeminiClient:
|
||||
}
|
||||
|
||||
def create(self, params: Dict) -> ChatCompletion:
|
||||
if self.use_vertexai:
|
||||
self._initialize_vartexai(**params)
|
||||
else:
|
||||
assert ("project_id" not in params) and (
|
||||
"location" not in params
|
||||
), "Google Cloud project and compute location cannot be set when using an API Key!"
|
||||
model_name = params.get("model", "gemini-pro")
|
||||
if not model_name:
|
||||
raise ValueError(
|
||||
@@ -133,13 +178,17 @@ class GeminiClient:
|
||||
|
||||
if "vision" not in model_name:
|
||||
# A. create and call the chat model.
|
||||
gemini_messages = oai_messages_to_gemini_messages(messages)
|
||||
|
||||
# we use chat model by default
|
||||
model = genai.GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
gemini_messages = self._oai_messages_to_gemini_messages(messages)
|
||||
if self.use_vertexai:
|
||||
model = GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
)
|
||||
else:
|
||||
# we use chat model by default
|
||||
model = genai.GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
chat = model.start_chat(history=gemini_messages[:-1])
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
@@ -167,14 +216,19 @@ class GeminiClient:
|
||||
completion_tokens = model.count_tokens(ans).total_tokens
|
||||
elif model_name == "gemini-pro-vision":
|
||||
# B. handle the vision model
|
||||
if self.use_vertexai:
|
||||
model = GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
)
|
||||
else:
|
||||
model = genai.GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
# Gemini's vision model does not support chat history yet
|
||||
model = genai.GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
# chat = model.start_chat(history=gemini_messages[:-1])
|
||||
# response = chat.send_message(gemini_messages[-1])
|
||||
user_message = oai_content_to_gemini_content(messages[-1]["content"])
|
||||
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
|
||||
if len(messages) > 2:
|
||||
warnings.warn(
|
||||
"Warning: Gemini's vision model does not support chat history yet.",
|
||||
@@ -184,7 +238,10 @@ class GeminiClient:
|
||||
|
||||
response = model.generate_content(user_message, stream=stream)
|
||||
# ans = response.text
|
||||
ans: str = response._result.candidates[0].content.parts[0].text
|
||||
if self.use_vertexai:
|
||||
ans: str = response.candidates[0].content.parts[0].text
|
||||
else:
|
||||
ans: str = response._result.candidates[0].content.parts[0].text
|
||||
|
||||
prompt_tokens = model.count_tokens(user_message).total_tokens
|
||||
completion_tokens = model.count_tokens(ans).total_tokens
|
||||
@@ -209,99 +266,111 @@ class GeminiClient:
|
||||
|
||||
return response_oai
|
||||
|
||||
def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
|
||||
"""Convert content from OAI format to Gemini format"""
|
||||
rst = []
|
||||
if isinstance(content, str):
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIPart.from_text(content))
|
||||
else:
|
||||
rst.append(Part(text=content))
|
||||
return rst
|
||||
|
||||
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
|
||||
if "1.5" in model_name or "gemini-experimental" in model_name:
|
||||
# "gemini-1.5-pro-preview-0409"
|
||||
# Cost is $7 per million input tokens and $21 per million output tokens
|
||||
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
|
||||
assert isinstance(content, list)
|
||||
|
||||
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
|
||||
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
|
||||
|
||||
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
|
||||
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
|
||||
|
||||
|
||||
def oai_content_to_gemini_content(content: Union[str, List]) -> List:
|
||||
"""Convert content from OAI format to Gemini format"""
|
||||
rst = []
|
||||
if isinstance(content, str):
|
||||
rst.append(Part(text=content))
|
||||
for msg in content:
|
||||
if isinstance(msg, dict):
|
||||
assert "type" in msg, f"Missing 'type' field in message: {msg}"
|
||||
if msg["type"] == "text":
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIPart.from_text(text=msg["text"]))
|
||||
else:
|
||||
rst.append(Part(text=msg["text"]))
|
||||
elif msg["type"] == "image_url":
|
||||
if self.use_vertexai:
|
||||
img_url = msg["image_url"]["url"]
|
||||
re.match(r"data:image/(?:png|jpeg);base64,", img_url)
|
||||
img = get_image_data(img_url, use_b64=False)
|
||||
# image/png works with jpeg as well
|
||||
img_part = VertexAIPart.from_data(img, mime_type="image/png")
|
||||
rst.append(img_part)
|
||||
else:
|
||||
b64_img = get_image_data(msg["image_url"]["url"])
|
||||
img = _to_pil(b64_img)
|
||||
rst.append(img)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {msg['type']}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(msg)}")
|
||||
return rst
|
||||
|
||||
assert isinstance(content, list)
|
||||
def _concat_parts(self, parts: List[Part]) -> List:
|
||||
"""Concatenate parts with the same type.
|
||||
If two adjacent parts both have the "text" attribute, then it will be joined into one part.
|
||||
"""
|
||||
if not parts:
|
||||
return []
|
||||
|
||||
for msg in content:
|
||||
if isinstance(msg, dict):
|
||||
assert "type" in msg, f"Missing 'type' field in message: {msg}"
|
||||
if msg["type"] == "text":
|
||||
rst.append(Part(text=msg["text"]))
|
||||
elif msg["type"] == "image_url":
|
||||
b64_img = get_image_data(msg["image_url"]["url"])
|
||||
img = _to_pil(b64_img)
|
||||
rst.append(img)
|
||||
concatenated_parts = []
|
||||
previous_part = parts[0]
|
||||
|
||||
for current_part in parts[1:]:
|
||||
if previous_part.text != "":
|
||||
if self.use_vertexai:
|
||||
previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
|
||||
else:
|
||||
previous_part.text += current_part.text
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {msg['type']}")
|
||||
concatenated_parts.append(previous_part)
|
||||
previous_part = current_part
|
||||
|
||||
if previous_part.text == "":
|
||||
if self.use_vertexai:
|
||||
previous_part = VertexAIPart.from_text("empty")
|
||||
else:
|
||||
previous_part.text = "empty" # Empty content is not allowed.
|
||||
concatenated_parts.append(previous_part)
|
||||
|
||||
return concatenated_parts
|
||||
|
||||
def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages from OAI format to Gemini format.
|
||||
Make sure the "user" role and "model" role are interleaved.
|
||||
Also, make sure the last item is from the "user" role.
|
||||
"""
|
||||
prev_role = None
|
||||
rst = []
|
||||
curr_parts = []
|
||||
for i, message in enumerate(messages):
|
||||
parts = self._oai_content_to_gemini_content(message["content"])
|
||||
role = "user" if message["role"] in ["user", "system"] else "model"
|
||||
|
||||
if prev_role is None or role == prev_role:
|
||||
curr_parts += parts
|
||||
elif role != prev_role:
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role))
|
||||
else:
|
||||
rst.append(Content(parts=curr_parts, role=prev_role))
|
||||
prev_role = role
|
||||
|
||||
# handle the last message
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role))
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(msg)}")
|
||||
return rst
|
||||
rst.append(Content(parts=curr_parts, role=role))
|
||||
|
||||
# The Gemini is restrict on order of roles, such that
|
||||
# 1. The messages should be interleaved between user and model.
|
||||
# 2. The last message must be from the user role.
|
||||
# We add a dummy message "continue" if the last role is not the user.
|
||||
if rst[-1].role != "user":
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user"))
|
||||
else:
|
||||
rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user"))
|
||||
|
||||
def concat_parts(parts: List[Part]) -> List:
|
||||
"""Concatenate parts with the same type.
|
||||
If two adjacent parts both have the "text" attribute, then it will be joined into one part.
|
||||
"""
|
||||
if not parts:
|
||||
return []
|
||||
|
||||
concatenated_parts = []
|
||||
previous_part = parts[0]
|
||||
|
||||
for current_part in parts[1:]:
|
||||
if previous_part.text != "":
|
||||
previous_part.text += current_part.text
|
||||
else:
|
||||
concatenated_parts.append(previous_part)
|
||||
previous_part = current_part
|
||||
|
||||
if previous_part.text == "":
|
||||
previous_part.text = "empty" # Empty content is not allowed.
|
||||
concatenated_parts.append(previous_part)
|
||||
|
||||
return concatenated_parts
|
||||
|
||||
|
||||
def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages from OAI format to Gemini format.
|
||||
Make sure the "user" role and "model" role are interleaved.
|
||||
Also, make sure the last item is from the "user" role.
|
||||
"""
|
||||
prev_role = None
|
||||
rst = []
|
||||
curr_parts = []
|
||||
for i, message in enumerate(messages):
|
||||
parts = oai_content_to_gemini_content(message["content"])
|
||||
role = "user" if message["role"] in ["user", "system"] else "model"
|
||||
|
||||
if prev_role is None or role == prev_role:
|
||||
curr_parts += parts
|
||||
elif role != prev_role:
|
||||
rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
|
||||
curr_parts = parts
|
||||
prev_role = role
|
||||
|
||||
# handle the last message
|
||||
rst.append(Content(parts=concat_parts(curr_parts), role=role))
|
||||
|
||||
# The Gemini is restrict on order of roles, such that
|
||||
# 1. The messages should be interleaved between user and model.
|
||||
# 2. The last message must be from the user role.
|
||||
# We add a dummy message "continue" if the last role is not the user.
|
||||
if rst[-1].role != "user":
|
||||
rst.append(Content(parts=oai_content_to_gemini_content("continue"), role="user"))
|
||||
|
||||
return rst
|
||||
return rst
|
||||
|
||||
|
||||
def _to_pil(data: str) -> Image.Image:
|
||||
@@ -336,3 +405,16 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
else:
|
||||
return content
|
||||
|
||||
|
||||
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
|
||||
if "1.5" in model_name or "gemini-experimental" in model_name:
|
||||
# "gemini-1.5-pro-preview-0409"
|
||||
# Cost is $7 per million input tokens and $21 per million output tokens
|
||||
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
|
||||
|
||||
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
|
||||
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
|
||||
|
||||
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
|
||||
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
|
||||
|
||||
2
setup.py
2
setup.py
@@ -80,7 +80,7 @@ extra_require = {
|
||||
"teachable": ["chromadb"],
|
||||
"lmm": ["replicate", "pillow"],
|
||||
"graph": ["networkx", "matplotlib"],
|
||||
"gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
|
||||
"gemini": ["google-generativeai>=0.5,<1", "google-cloud-aiplatform", "google-auth", "pillow", "pydantic"],
|
||||
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
|
||||
"redis": ["redis"],
|
||||
"cosmosdb": ["azure-cosmos>=4.2.0"],
|
||||
|
||||
@@ -33,11 +33,18 @@ def gemini_client():
|
||||
return GeminiClient(api_key="fake_api_key")
|
||||
|
||||
|
||||
# Test initialization and configuration
|
||||
# Test compute location initialization and configuration
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_initialization():
|
||||
def test_compute_location_initialization():
|
||||
with pytest.raises(AssertionError):
|
||||
GeminiClient() # Should raise an AssertionError due to missing API key
|
||||
GeminiClient(
|
||||
api_key="fake_api_key", location="us-west1"
|
||||
) # Should raise an AssertionError due to specifying API key and compute location
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_google_auth_default_client():
|
||||
return GeminiClient()
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
|
||||
@@ -24,11 +24,13 @@
|
||||
"\n",
|
||||
"## Features\n",
|
||||
"\n",
|
||||
"There's no need to handle OpenAI or Google's GenAI packages separately; AutoGen manages all of these for you. You can easily create different agents with various backend LLMs using the assistant agent. All models and agents are readily accessible at your fingertips.\n",
|
||||
"There's no need to handle OpenAI or Google's GenAI packages separately; AutoGen manages all of these for you. You can easily create different agents with various backend LLMs using the assistant agent. All models and agents are readily accessible at your fingertips. \n",
|
||||
" \n",
|
||||
"\n",
|
||||
"## Main Distinctions\n",
|
||||
"\n",
|
||||
"- Currently, Gemini does not include a \"system_message\" field. However, you can incorporate this instruction into the first message of your interaction."
|
||||
"- Currently, Gemini does not include a \"system_message\" field. However, you can incorporate this instruction into the first message of your interaction.\n",
|
||||
"- If no API key is specified for Gemini, then authentication will happen using the default google auth mechanism for Google Cloud. Service accounts are also supported, where the JSON key file has to be provided."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -57,6 +59,16 @@
|
||||
" \"api_type\": \"google\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"model\": \"gemini-1.5-pro-001\",\n",
|
||||
" \"api_type\": \"google\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"model\": \"gemini-1.5-pro\",\n",
|
||||
" \"project\": \"your-awesome-google-cloud-project-id\",\n",
|
||||
" \"location\": \"us-west1\",\n",
|
||||
" \"google_application_credentials\": \"your-google-service-account-key.json\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"model\": \"gemini-pro-vision\",\n",
|
||||
" \"api_key\": \"your Google's GenAI Key goes here\",\n",
|
||||
" \"api_type\": \"google\"\n",
|
||||
@@ -110,7 +122,7 @@
|
||||
"config_list_gemini = autogen.config_list_from_json(\n",
|
||||
" \"OAI_CONFIG_LIST\",\n",
|
||||
" filter_dict={\n",
|
||||
" \"model\": [\"gemini-pro\"],\n",
|
||||
" \"model\": [\"gemini-pro\", \"gemini-1.5-pro\", \"gemini-1.5-pro-001\"],\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
|
||||
Reference in New Issue
Block a user