feat(forge/llm): Add GroqProvider (#7130)

* Add `GroqProvider` in `forge.llm.providers.groq`
  * Add to `llm.providers.multi`
  * Add `groq` dependency (v0.8.0)

* Update AutoGPT docs & config template
  * Update .env.template
  * Update docs
Author: Reinier van der Leer
Date: 2024-05-24 16:34:51 +02:00
Committed by: GitHub
Parent: cdae98d36b
Commit: edcbbbce25

8 changed files with 479 additions and 8 deletions
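The new provider implements the same `ChatModelProvider` interface as the existing Anthropic and OpenAI providers. A minimal usage sketch, assuming a valid `GROQ_API_KEY` in the environment (names are taken from the diffs below; the prompt text is illustrative):

```python
import asyncio

from forge.llm.providers.groq import GroqModelName, GroqProvider
from forge.llm.providers.schema import ChatMessage


async def main() -> None:
    # No settings passed: default settings are copied and credentials are
    # loaded from the environment (GROQ_API_KEY, optionally GROQ_API_BASE_URL).
    provider = GroqProvider()
    response = await provider.create_chat_completion(
        model_prompt=[ChatMessage(role="user", content="Say hello in one word.")],
        model_name=GroqModelName.LLAMA3_8B,
    )
    print(response.response.content)


asyncio.run(main())
```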


@@ -8,6 +8,9 @@
## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
# ANTHROPIC_API_KEY=
## GROQ_API_KEY - Groq API Key (Example: gsk_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
# GROQ_API_KEY=
## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
# TELEMETRY_OPT_IN=true
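The key configured here is read by the new `GroqCredentials` class added in this commit. A quick hedged check that it loads (assuming the key is set; `from_env()` comes from the shared credentials model):

```python
from forge.llm.providers.groq import GroqCredentials

credentials = GroqCredentials.from_env()  # reads GROQ_API_KEY / GROQ_API_BASE_URL
print(sorted(credentials.get_api_access_kwargs()))  # ['api_key'] if no base URL is set
```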

autogpt/poetry.lock (generated)

@@ -330,6 +330,7 @@ fastapi = "^0.109.1"
gitpython = "^3.1.32"
google-api-python-client = "*"
google-cloud-storage = "^2.13.0"
groq = "^0.8.0"
jinja2 = "^3.1.2"
jsonschema = "*"
litellm = "^1.17.9"
@@ -355,7 +356,7 @@ uvicorn = "^0.23.2"
webdriver-manager = "^4.0.1"
[package.extras]
benchmark = ["agbenchmark @ file:///Users/czerwinski/Projects/AutoGPT/benchmark"]
benchmark = ["agbenchmark @ file:///home/reinier/code/agpt/Auto-GPT/benchmark"]
[package.source]
type = "directory"
@@ -2444,6 +2445,25 @@ files = [
docs = ["Sphinx", "furo"]
test = ["objgraph", "psutil"]
[[package]]
name = "groq"
version = "0.8.0"
description = "The official Python library for the groq API"
optional = false
python-versions = ">=3.7"
files = [
{file = "groq-0.8.0-py3-none-any.whl", hash = "sha256:f5e4e892d45001241a930db451e633ca1f0007e3f749deaa5d7360062fcd61e3"},
{file = "groq-0.8.0.tar.gz", hash = "sha256:37ceb2f706bd516d0bfcac8e89048a24b375172987a0d6bd9efb521c54f6deff"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
typing-extensions = ">=4.7,<5"
[[package]]
name = "grpc-google-iam-v1"
version = "0.13.0"
@@ -5255,7 +5275,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},


@@ -22,6 +22,7 @@ Configuration is controlled through the `Config` object. You can set configurati
- `GITHUB_USERNAME`: GitHub Username. Optional.
- `GOOGLE_API_KEY`: Google API key. Optional.
- `GOOGLE_CUSTOM_SEARCH_ENGINE_ID`: [Google custom search engine ID](https://programmablesearchengine.google.com/controlpanel/all). Optional.
- `GROQ_API_KEY`: Groq API key. Set this if you want to use Groq models with AutoGPT. Optional.
- `HEADLESS_BROWSER`: Use a headless browser while AutoGPT uses a web browser. Setting to `False` will allow you to see AutoGPT operate the browser. Default: True
- `HUGGINGFACE_API_TOKEN`: HuggingFace API, to be used for both image generation and audio to text. Optional.
- `HUGGINGFACE_AUDIO_TO_TEXT_MODEL`: HuggingFace audio to text model. Default: CompVis/stable-diffusion-v1-4
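A quick pre-flight check that the new key is visible to the process can save a failed run; a stdlib-only sketch (the error message is illustrative):

```python
import os

if not os.getenv("GROQ_API_KEY"):
    raise SystemExit("GROQ_API_KEY is not set; Groq models will be unavailable.")
```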


@@ -0,0 +1,426 @@
from __future__ import annotations
import enum
import logging
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
import sentry_sdk
import tenacity
import tiktoken
from groq import APIConnectionError, APIStatusError
from pydantic import SecretStr
from forge.json.parsing import json_loads
from forge.llm.providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
)
from forge.models.config import Configurable, UserConfigurable
from .utils import validate_tool_calls
if TYPE_CHECKING:
from groq.types.chat import ChatCompletion, CompletionCreateParams
from groq.types.chat.chat_completion_message import ChatCompletionMessage
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
_T = TypeVar("_T")
_P = ParamSpec("_P")
class GroqModelName(str, enum.Enum):
LLAMA3_8B = "llama3-8b-8192"
LLAMA3_70B = "llama3-70b-8192"
MIXTRAL_8X7B = "mixtral-8x7b-32768"
GEMMA_7B = "gemma-7b-it"
GROQ_CHAT_MODELS = {
info.name: info
for info in [
ChatModelInfo(
name=GroqModelName.LLAMA3_8B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.05 / 1e6,
completion_token_cost=0.10 / 1e6,
max_tokens=8192,
has_function_call_api=True,
),
ChatModelInfo(
name=GroqModelName.LLAMA3_70B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.59 / 1e6,
completion_token_cost=0.79 / 1e6,
max_tokens=8192,
has_function_call_api=True,
),
ChatModelInfo(
name=GroqModelName.MIXTRAL_8X7B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.27 / 1e6,
completion_token_cost=0.27 / 1e6,
max_tokens=32768,
has_function_call_api=True,
),
ChatModelInfo(
name=GroqModelName.GEMMA_7B,
provider_name=ModelProviderName.GROQ,
prompt_token_cost=0.10 / 1e6,
completion_token_cost=0.10 / 1e6,
max_tokens=8192,
has_function_call_api=True,
),
]
}
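# NOTE: token costs above are USD per token; `/ 1e6` converts from the
# per-million-token prices. E.g. a llama3-70b call using 1,000 prompt and
# 500 completion tokens costs 1000 * 0.59e-6 + 500 * 0.79e-6 = $0.000985.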
class GroqConfiguration(ModelProviderConfiguration):
fix_failed_parse_tries: int = UserConfigurable(3)
class GroqCredentials(ModelProviderCredentials):
"""Credentials for Groq."""
api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY")
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="GROQ_API_BASE_URL"
)
def get_api_access_kwargs(self) -> dict[str, str]:
return {
k: (v.get_secret_value() if type(v) is SecretStr else v)
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
}.items()
if v is not None
}
class GroqSettings(ModelProviderSettings):
configuration: GroqConfiguration
credentials: Optional[GroqCredentials]
budget: ModelProviderBudget
class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
default_settings = GroqSettings(
name="groq_provider",
description="Provides access to Groq's API.",
configuration=GroqConfiguration(
retries_per_request=7,
),
credentials=None,
budget=ModelProviderBudget(),
)
_settings: GroqSettings
_configuration: GroqConfiguration
_credentials: GroqCredentials
_budget: ModelProviderBudget
def __init__(
self,
settings: Optional[GroqSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not settings:
settings = self.default_settings.copy(deep=True)
if not settings.credentials:
settings.credentials = GroqCredentials.from_env()
super(GroqProvider, self).__init__(settings=settings, logger=logger)
from groq import AsyncGroq
self._client = AsyncGroq(**self._credentials.get_api_access_kwargs())
async def get_available_models(self) -> list[ChatModelInfo]:
_models = (await self._client.models.list()).data
return [GROQ_CHAT_MODELS[m.id] for m in _models if m.id in GROQ_CHAT_MODELS]
def get_token_limit(self, model_name: str) -> int:
"""Get the token limit for a given model."""
return GROQ_CHAT_MODELS[model_name].max_tokens
@classmethod
def get_tokenizer(cls, model_name: GroqModelName) -> ModelTokenizer:
# HACK: No official tokenizer is available for Groq
return tiktoken.encoding_for_model("gpt-3.5-turbo")
@classmethod
def count_tokens(cls, text: str, model_name: GroqModelName) -> int:
return len(cls.get_tokenizer(model_name).encode(text))
@classmethod
def count_message_tokens(
cls,
messages: ChatMessage | list[ChatMessage],
model_name: GroqModelName,
) -> int:
if isinstance(messages, ChatMessage):
messages = [messages]
# HACK: No official tokenizer (for text or messages) is available for Groq.
# Token overhead of messages is unknown and may be inaccurate.
return cls.count_tokens(
"\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name
)
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: GroqModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the Groq API."""
groq_messages, completion_kwargs = self._get_chat_completion_args(
prompt_messages=model_prompt,
model=model_name,
functions=functions,
max_output_tokens=max_output_tokens,
**kwargs,
)
total_cost = 0.0
attempts = 0
while True:
completion_kwargs["messages"] = groq_messages.copy()
_response, _cost, t_input, t_output = await self._create_chat_completion(
completion_kwargs
)
total_cost += _cost
# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
attempts += 1
parse_errors: list[Exception] = []
_assistant_msg = _response.choices[0].message
tool_calls, _errors = self._parse_assistant_tool_calls(_assistant_msg)
parse_errors += _errors
# Validate tool calls
if not parse_errors and tool_calls and functions:
parse_errors += validate_tool_calls(tool_calls, functions)
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=tool_calls or None,
)
parsed_result: _T = None # type: ignore
if not parse_errors:
try:
parsed_result = completion_parser(assistant_msg)
except Exception as e:
parse_errors.append(e)
if not parse_errors:
if attempts > 1:
self._logger.debug(
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
)
return ChatModelResponse(
response=AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=tool_calls or None,
),
parsed_result=parsed_result,
model_info=GROQ_CHAT_MODELS[model_name],
prompt_tokens_used=t_input,
completion_tokens_used=t_output,
)
else:
self._logger.debug(
f"Parsing failed on response: '''{_assistant_msg}'''"
)
parse_errors_fmt = "\n\n".join(
f"{e.__class__.__name__}: {e}" for e in parse_errors
)
self._logger.warning(
f"Parsing attempt #{attempts} failed: {parse_errors_fmt}"
)
for e in parse_errors:
sentry_sdk.capture_exception(
error=e,
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
)
if attempts < self._configuration.fix_failed_parse_tries:
groq_messages.append(_assistant_msg.dict(exclude_none=True))
groq_messages.append(
{
"role": "system",
"content": (
f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}"
),
}
)
continue
else:
raise parse_errors[0]
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: GroqModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs, # type: ignore
) -> tuple[list[ChatCompletionMessageParam], CompletionCreateParams]:
"""Prepare chat completion arguments and keyword arguments for API call.
Args:
model_prompt: List of ChatMessages.
model_name: The model to use.
functions: Optional list of functions available to the LLM.
kwargs: Additional keyword arguments.
Returns:
list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
dict[str, Any]: Any other kwargs for the OpenAI call
"""
kwargs: CompletionCreateParams = kwargs # type: ignore
kwargs["model"] = model
if max_output_tokens:
kwargs["max_tokens"] = max_output_tokens
if functions:
kwargs["tools"] = [
{"type": "function", "function": f.schema} for f in functions
]
if len(functions) == 1:
# force the model to call the only specified function
kwargs["tool_choice"] = {
"type": "function",
"function": {"name": functions[0].name},
}
if extra_headers := self._configuration.extra_request_headers:
# 'extra_headers' is not on CompletionCreateParams, but is on chat.create()
kwargs["extra_headers"] = kwargs.get("extra_headers", {}) # type: ignore
kwargs["extra_headers"].update(extra_headers.copy()) # type: ignore
groq_messages: list[ChatCompletionMessageParam] = [
message.dict(
include={"role", "content", "tool_calls", "tool_call_id", "name"},
exclude_none=True,
)
for message in prompt_messages
]
if "messages" in kwargs:
groq_messages += kwargs["messages"]
del kwargs["messages"] # type: ignore - messages are added back later
return groq_messages, kwargs
async def _create_chat_completion(
self, completion_kwargs: CompletionCreateParams
) -> tuple[ChatCompletion, float, int, int]:
"""
Create a chat completion using the Groq API with retry handling.
Params:
completion_kwargs: Keyword arguments for an Groq Messages API call
Returns:
Message: The message completion object
float: The cost ($) of this completion
int: Number of input tokens used
int: Number of output tokens used
"""
@self._retry_api_request
async def _create_chat_completion_with_retry(
completion_kwargs: CompletionCreateParams,
) -> ChatCompletion:
return await self._client.chat.completions.create(**completion_kwargs)
response = await _create_chat_completion_with_retry(completion_kwargs)
cost = self._budget.update_usage_and_cost(
model_info=GROQ_CHAT_MODELS[completion_kwargs["model"]],
input_tokens_used=response.usage.prompt_tokens,
output_tokens_used=response.usage.completion_tokens,
)
return (
response,
cost,
response.usage.prompt_tokens,
response.usage.completion_tokens,
)
def _parse_assistant_tool_calls(
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
):
tool_calls: list[AssistantToolCall] = []
parse_errors: list[Exception] = []
if assistant_message.tool_calls:
for _tc in assistant_message.tool_calls:
try:
parsed_arguments = json_loads(_tc.function.arguments)
except Exception as e:
err_message = (
f"Decoding arguments for {_tc.function.name} failed: "
+ str(e.args[0])
)
parse_errors.append(
type(e)(err_message, *e.args[1:]).with_traceback(
e.__traceback__
)
)
continue
tool_calls.append(
AssistantToolCall(
id=_tc.id,
type=_tc.type,
function=AssistantFunctionCall(
name=_tc.function.name,
arguments=parsed_arguments,
),
)
)
# If parsing of all tool calls succeeds in the end, we ignore any issues
if len(tool_calls) == len(assistant_message.tool_calls):
parse_errors = []
return tool_calls, parse_errors
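# Retry helper: retries connection errors and server-side (5xx) errors with
# exponential backoff, up to `retries_per_request` attempts (7 per the
# default settings above).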
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
return tenacity.retry(
retry=(
tenacity.retry_if_exception_type(APIConnectionError)
| tenacity.retry_if_exception(
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
)
),
wait=tenacity.wait_exponential(),
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
after=tenacity.after_log(self._logger, logging.DEBUG),
)(func)
def __repr__(self):
return "GroqProvider()"


@@ -8,6 +8,7 @@ from pydantic import ValidationError
from forge.models.config import Configurable
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
from .schema import (
AssistantChatMessage,
@@ -25,9 +26,9 @@ from .schema import (
_T = TypeVar("_T")
-ModelName = AnthropicModelName | OpenAIModelName
+ModelName = AnthropicModelName | GroqModelName | OpenAIModelName
-CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
+CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
@@ -143,16 +144,17 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
@classmethod
def _get_model_provider_class(
cls, model_name: ModelName
-) -> type[AnthropicProvider | OpenAIProvider]:
+) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
@classmethod
def _get_provider_class(
cls, provider_name: ModelProviderName
-) -> type[AnthropicProvider | OpenAIProvider]:
+) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
try:
return {
ModelProviderName.ANTHROPIC: AnthropicProvider,
ModelProviderName.GROQ: GroqProvider,
ModelProviderName.OPENAI: OpenAIProvider,
}[provider_name]
except KeyError:
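With these changes `MultiProvider` resolves Groq model names to the new provider class. A small routing check (a sketch using the classmethods shown above):

```python
from forge.llm.providers.groq import GroqModelName, GroqProvider
from forge.llm.providers.multi import CHAT_MODELS, MultiProvider
from forge.llm.providers.schema import ModelProviderName

model = GroqModelName.LLAMA3_70B
assert CHAT_MODELS[model].provider_name == ModelProviderName.GROQ
assert MultiProvider._get_model_provider_class(model) is GroqProvider
```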


@@ -45,6 +45,7 @@ class ModelProviderService(str, enum.Enum):
class ModelProviderName(str, enum.Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
class ChatMessage(BaseModel):
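Since `ModelProviderName` is a `str` enum, the new member round-trips from plain config strings (a sketch):

```python
from forge.llm.providers.schema import ModelProviderName

assert ModelProviderName("groq") is ModelProviderName.GROQ
assert ModelProviderName.GROQ == "groq"  # str enum compares equal to its value
```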

forge/poetry.lock (generated)

@@ -1762,6 +1762,25 @@ files = [
docs = ["Sphinx", "furo"]
test = ["objgraph", "psutil"]
[[package]]
name = "groq"
version = "0.8.0"
description = "The official Python library for the groq API"
optional = false
python-versions = ">=3.7"
files = [
{file = "groq-0.8.0-py3-none-any.whl", hash = "sha256:f5e4e892d45001241a930db451e633ca1f0007e3f749deaa5d7360062fcd61e3"},
{file = "groq-0.8.0.tar.gz", hash = "sha256:37ceb2f706bd516d0bfcac8e89048a24b375172987a0d6bd9efb521c54f6deff"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
typing-extensions = ">=4.7,<5"
[[package]]
name = "grpcio"
version = "1.60.0"
@@ -4277,7 +4296,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -6117,4 +6135,4 @@ benchmark = ["agbenchmark"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "5e493a1132ffacec11607b5ca1c91147d5b43c47af6c1ac039a235b3a68fcb1e"
content-hash = "19547e2de5ebeda0d3ef171b5dda23aecb6003c48f633ef09367095463206c3f"


@@ -27,6 +27,7 @@ fastapi = "^0.109.1"
gitpython = "^3.1.32"
google-api-python-client = "*"
google-cloud-storage = "^2.13.0"
groq = "^0.8.0"
jinja2 = "^3.1.2"
jsonschema = "*"
litellm = "^1.17.9"