Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-08 22:58:01 -05:00)
feat(forge/llm): Add GroqProvider (#7130)
* Add `GroqProvider` in `forge.llm.providers.groq`
* Add to `llm.providers.multi`
* Add `groq` dependency (v0.8.0)
* Update AutoGPT docs & config template
* Update .env.template
* Update docs
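
For reviewers, a minimal usage sketch of what this change enables (illustrative only; assumes GROQ_API_KEY is set in the environment, and assumes the `ChatMessage.user(...)` helper from `forge.llm.providers.schema`):

import asyncio

from forge.llm.providers.groq import GroqModelName, GroqProvider
from forge.llm.providers.schema import ChatMessage


async def main():
    # With no explicit settings, credentials are loaded from GROQ_API_KEY
    # via GroqCredentials.from_env() (see groq.py below).
    provider = GroqProvider()
    response = await provider.create_chat_completion(
        model_prompt=[ChatMessage.user("Say hello in one sentence.")],
        model_name=GroqModelName.LLAMA3_8B,
    )
    print(response.response.content)


asyncio.run(main())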
Commit edcbbbce25 (parent cdae98d36b), committed via GitHub.
.env.template
@@ -8,6 +8,9 @@
 ## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
 # ANTHROPIC_API_KEY=
 
+## GROQ_API_KEY - Groq API Key (Example: gsk_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
+# GROQ_API_KEY=
+
 ## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
 ## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
 # TELEMETRY_OPT_IN=true
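
A quick sanity check that the new template variable is wired up, as a sketch using the `GroqCredentials` class introduced in groq.py below (the key value is a placeholder):

import os

from forge.llm.providers.groq import GroqCredentials

# Placeholder key with the same shape as the template example above
os.environ["GROQ_API_KEY"] = "gsk_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

credentials = GroqCredentials.from_env()
# get_api_access_kwargs() unwraps the SecretStr values for the Groq client;
# "base_url" is omitted here because GROQ_API_BASE_URL is not set.
print(credentials.get_api_access_kwargs())  # -> {'api_key': 'gsk_xxx...'}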
autogpt/poetry.lock (generated; 23 lines changed)
@@ -330,6 +330,7 @@ fastapi = "^0.109.1"
 gitpython = "^3.1.32"
 google-api-python-client = "*"
 google-cloud-storage = "^2.13.0"
+groq = "^0.8.0"
 jinja2 = "^3.1.2"
 jsonschema = "*"
 litellm = "^1.17.9"
@@ -355,7 +356,7 @@ uvicorn = "^0.23.2"
 webdriver-manager = "^4.0.1"
 
 [package.extras]
-benchmark = ["agbenchmark @ file:///Users/czerwinski/Projects/AutoGPT/benchmark"]
+benchmark = ["agbenchmark @ file:///home/reinier/code/agpt/Auto-GPT/benchmark"]
 
 [package.source]
 type = "directory"
@@ -2444,6 +2445,25 @@ files = [
 docs = ["Sphinx", "furo"]
 test = ["objgraph", "psutil"]
 
+[[package]]
+name = "groq"
+version = "0.8.0"
+description = "The official Python library for the groq API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "groq-0.8.0-py3-none-any.whl", hash = "sha256:f5e4e892d45001241a930db451e633ca1f0007e3f749deaa5d7360062fcd61e3"},
+    {file = "groq-0.8.0.tar.gz", hash = "sha256:37ceb2f706bd516d0bfcac8e89048a24b375172987a0d6bd9efb521c54f6deff"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.7,<5"
+
 [[package]]
 name = "grpc-google-iam-v1"
 version = "0.13.0"
@@ -5255,7 +5275,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
AutoGPT docs: configuration options
@@ -22,6 +22,7 @@ Configuration is controlled through the `Config` object. You can set configurati
 - `GITHUB_USERNAME`: GitHub Username. Optional.
 - `GOOGLE_API_KEY`: Google API key. Optional.
 - `GOOGLE_CUSTOM_SEARCH_ENGINE_ID`: [Google custom search engine ID](https://programmablesearchengine.google.com/controlpanel/all). Optional.
+- `GROQ_API_KEY`: Set this if you want to use Groq models with AutoGPT. Optional.
 - `HEADLESS_BROWSER`: Use a headless browser while AutoGPT uses a web browser. Setting to `False` will allow you to see AutoGPT operate the browser. Default: True
 - `HUGGINGFACE_API_TOKEN`: HuggingFace API, to be used for both image generation and audio to text. Optional.
 - `HUGGINGFACE_AUDIO_TO_TEXT_MODEL`: HuggingFace audio to text model. Default: CompVis/stable-diffusion-v1-4
forge/forge/llm/providers/groq.py (new file, 426 lines)
@@ -0,0 +1,426 @@
from __future__ import annotations

import enum
import logging
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar

import sentry_sdk
import tenacity
import tiktoken
from groq import APIConnectionError, APIStatusError
from pydantic import SecretStr

from forge.json.parsing import json_loads
from forge.llm.providers.schema import (
    AssistantChatMessage,
    AssistantFunctionCall,
    AssistantToolCall,
    ChatMessage,
    ChatModelInfo,
    ChatModelProvider,
    ChatModelResponse,
    CompletionModelFunction,
    ModelProviderBudget,
    ModelProviderConfiguration,
    ModelProviderCredentials,
    ModelProviderName,
    ModelProviderSettings,
    ModelTokenizer,
)
from forge.models.config import Configurable, UserConfigurable

from .utils import validate_tool_calls

if TYPE_CHECKING:
    from groq.types.chat import ChatCompletion, CompletionCreateParams
    from groq.types.chat.chat_completion_message import ChatCompletionMessage
    from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam

_T = TypeVar("_T")
_P = ParamSpec("_P")
class GroqModelName(str, enum.Enum):
    LLAMA3_8B = "llama3-8b-8192"
    LLAMA3_70B = "llama3-70b-8192"
    MIXTRAL_8X7B = "mixtral-8x7b-32768"
    GEMMA_7B = "gemma-7b-it"


GROQ_CHAT_MODELS = {
    info.name: info
    for info in [
        ChatModelInfo(
            name=GroqModelName.LLAMA3_8B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.05 / 1e6,
            completion_token_cost=0.10 / 1e6,
            max_tokens=8192,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=GroqModelName.LLAMA3_70B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.59 / 1e6,
            completion_token_cost=0.79 / 1e6,
            max_tokens=8192,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=GroqModelName.MIXTRAL_8X7B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.27 / 1e6,
            completion_token_cost=0.27 / 1e6,
            max_tokens=32768,
            has_function_call_api=True,
        ),
        ChatModelInfo(
            name=GroqModelName.GEMMA_7B,
            provider_name=ModelProviderName.GROQ,
            prompt_token_cost=0.10 / 1e6,
            completion_token_cost=0.10 / 1e6,
            max_tokens=8192,
            has_function_call_api=True,
        ),
    ]
}


class GroqConfiguration(ModelProviderConfiguration):
    fix_failed_parse_tries: int = UserConfigurable(3)


class GroqCredentials(ModelProviderCredentials):
    """Credentials for Groq."""

    api_key: SecretStr = UserConfigurable(from_env="GROQ_API_KEY")
    api_base: Optional[SecretStr] = UserConfigurable(
        default=None, from_env="GROQ_API_BASE_URL"
    )

    def get_api_access_kwargs(self) -> dict[str, str]:
        return {
            k: (v.get_secret_value() if type(v) is SecretStr else v)
            for k, v in {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }.items()
            if v is not None
        }


class GroqSettings(ModelProviderSettings):
    configuration: GroqConfiguration
    credentials: Optional[GroqCredentials]
    budget: ModelProviderBudget


class GroqProvider(Configurable[GroqSettings], ChatModelProvider):
    default_settings = GroqSettings(
        name="groq_provider",
        description="Provides access to Groq's API.",
        configuration=GroqConfiguration(
            retries_per_request=7,
        ),
        credentials=None,
        budget=ModelProviderBudget(),
    )

    _settings: GroqSettings
    _configuration: GroqConfiguration
    _credentials: GroqCredentials
    _budget: ModelProviderBudget

    def __init__(
        self,
        settings: Optional[GroqSettings] = None,
        logger: Optional[logging.Logger] = None,
    ):
        if not settings:
            settings = self.default_settings.copy(deep=True)
        if not settings.credentials:
            settings.credentials = GroqCredentials.from_env()

        super(GroqProvider, self).__init__(settings=settings, logger=logger)

        from groq import AsyncGroq

        self._client = AsyncGroq(**self._credentials.get_api_access_kwargs())

    async def get_available_models(self) -> list[ChatModelInfo]:
        _models = (await self._client.models.list()).data
        return [GROQ_CHAT_MODELS[m.id] for m in _models if m.id in GROQ_CHAT_MODELS]

    def get_token_limit(self, model_name: str) -> int:
        """Get the token limit for a given model."""
        return GROQ_CHAT_MODELS[model_name].max_tokens

    @classmethod
    def get_tokenizer(cls, model_name: GroqModelName) -> ModelTokenizer:
        # HACK: No official tokenizer is available for Groq
        return tiktoken.encoding_for_model("gpt-3.5-turbo")

    @classmethod
    def count_tokens(cls, text: str, model_name: GroqModelName) -> int:
        return len(cls.get_tokenizer(model_name).encode(text))

    @classmethod
    def count_message_tokens(
        cls,
        messages: ChatMessage | list[ChatMessage],
        model_name: GroqModelName,
    ) -> int:
        if isinstance(messages, ChatMessage):
            messages = [messages]
        # HACK: No official tokenizer (for text or messages) is available for Groq.
        # Token overhead of messages is unknown and may be inaccurate.
        return cls.count_tokens(
            "\n\n".join(f"{m.role.upper()}: {m.content}" for m in messages), model_name
        )
    async def create_chat_completion(
        self,
        model_prompt: list[ChatMessage],
        model_name: GroqModelName,
        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        prefill_response: str = "",
        **kwargs,
    ) -> ChatModelResponse[_T]:
        """Create a completion using the Groq API."""
        groq_messages, completion_kwargs = self._get_chat_completion_args(
            prompt_messages=model_prompt,
            model=model_name,
            functions=functions,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

        total_cost = 0.0
        attempts = 0
        while True:
            completion_kwargs["messages"] = groq_messages.copy()
            _response, _cost, t_input, t_output = await self._create_chat_completion(
                completion_kwargs
            )
            total_cost += _cost

            # If parsing the response fails, append the error to the prompt, and let
            # the LLM fix its mistake(s).
            attempts += 1
            parse_errors: list[Exception] = []

            _assistant_msg = _response.choices[0].message

            tool_calls, _errors = self._parse_assistant_tool_calls(_assistant_msg)
            parse_errors += _errors

            # Validate tool calls
            if not parse_errors and tool_calls and functions:
                parse_errors += validate_tool_calls(tool_calls, functions)

            assistant_msg = AssistantChatMessage(
                content=_assistant_msg.content,
                tool_calls=tool_calls or None,
            )

            parsed_result: _T = None  # type: ignore
            if not parse_errors:
                try:
                    parsed_result = completion_parser(assistant_msg)
                except Exception as e:
                    parse_errors.append(e)

            if not parse_errors:
                if attempts > 1:
                    self._logger.debug(
                        f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
                    )

                return ChatModelResponse(
                    response=AssistantChatMessage(
                        content=_assistant_msg.content,
                        tool_calls=tool_calls or None,
                    ),
                    parsed_result=parsed_result,
                    model_info=GROQ_CHAT_MODELS[model_name],
                    prompt_tokens_used=t_input,
                    completion_tokens_used=t_output,
                )

            else:
                self._logger.debug(
                    f"Parsing failed on response: '''{_assistant_msg}'''"
                )
                parse_errors_fmt = "\n\n".join(
                    f"{e.__class__.__name__}: {e}" for e in parse_errors
                )
                self._logger.warning(
                    f"Parsing attempt #{attempts} failed: {parse_errors_fmt}"
                )
                for e in parse_errors:
                    sentry_sdk.capture_exception(
                        error=e,
                        extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
                    )

                if attempts < self._configuration.fix_failed_parse_tries:
                    groq_messages.append(_assistant_msg.dict(exclude_none=True))
                    groq_messages.append(
                        {
                            "role": "system",
                            "content": (
                                f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}"
                            ),
                        }
                    )
                    continue
                else:
                    raise parse_errors[0]
    def _get_chat_completion_args(
        self,
        prompt_messages: list[ChatMessage],
        model: GroqModelName,
        functions: Optional[list[CompletionModelFunction]] = None,
        max_output_tokens: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> tuple[list[ChatCompletionMessageParam], CompletionCreateParams]:
        """Prepare chat completion arguments and keyword arguments for API call.

        Args:
            prompt_messages: List of ChatMessages.
            model: The model to use.
            functions: Optional list of functions available to the LLM.
            max_output_tokens: Optional maximum number of tokens to generate.
            kwargs: Additional keyword arguments.

        Returns:
            list[ChatCompletionMessageParam]: Prompt messages for the Groq call
            dict[str, Any]: Any other kwargs for the Groq call
        """
        kwargs: CompletionCreateParams = kwargs  # type: ignore
        kwargs["model"] = model
        if max_output_tokens:
            kwargs["max_tokens"] = max_output_tokens

        if functions:
            kwargs["tools"] = [
                {"type": "function", "function": f.schema} for f in functions
            ]
            if len(functions) == 1:
                # force the model to call the only specified function
                kwargs["tool_choice"] = {
                    "type": "function",
                    "function": {"name": functions[0].name},
                }

        if extra_headers := self._configuration.extra_request_headers:
            # 'extra_headers' is not on CompletionCreateParams, but is on chat.create()
            kwargs["extra_headers"] = kwargs.get("extra_headers", {})  # type: ignore
            kwargs["extra_headers"].update(extra_headers.copy())  # type: ignore

        groq_messages: list[ChatCompletionMessageParam] = [
            message.dict(
                include={"role", "content", "tool_calls", "tool_call_id", "name"},
                exclude_none=True,
            )
            for message in prompt_messages
        ]

        if "messages" in kwargs:
            groq_messages += kwargs["messages"]
            del kwargs["messages"]  # type: ignore - messages are added back later

        return groq_messages, kwargs
    async def _create_chat_completion(
        self, completion_kwargs: CompletionCreateParams
    ) -> tuple[ChatCompletion, float, int, int]:
        """
        Create a chat completion using the Groq API with retry handling.

        Params:
            completion_kwargs: Keyword arguments for a Groq chat completion API call

        Returns:
            ChatCompletion: The chat completion response object
            float: The cost ($) of this completion
            int: Number of input tokens used
            int: Number of output tokens used
        """

        @self._retry_api_request
        async def _create_chat_completion_with_retry(
            completion_kwargs: CompletionCreateParams,
        ) -> ChatCompletion:
            return await self._client.chat.completions.create(**completion_kwargs)

        response = await _create_chat_completion_with_retry(completion_kwargs)

        cost = self._budget.update_usage_and_cost(
            model_info=GROQ_CHAT_MODELS[completion_kwargs["model"]],
            input_tokens_used=response.usage.prompt_tokens,
            output_tokens_used=response.usage.completion_tokens,
        )
        return (
            response,
            cost,
            response.usage.prompt_tokens,
            response.usage.completion_tokens,
        )
    def _parse_assistant_tool_calls(
        self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
    ):
        tool_calls: list[AssistantToolCall] = []
        parse_errors: list[Exception] = []

        if assistant_message.tool_calls:
            for _tc in assistant_message.tool_calls:
                try:
                    parsed_arguments = json_loads(_tc.function.arguments)
                except Exception as e:
                    err_message = (
                        f"Decoding arguments for {_tc.function.name} failed: "
                        + str(e.args[0])
                    )
                    parse_errors.append(
                        type(e)(err_message, *e.args[1:]).with_traceback(
                            e.__traceback__
                        )
                    )
                    continue

                tool_calls.append(
                    AssistantToolCall(
                        id=_tc.id,
                        type=_tc.type,
                        function=AssistantFunctionCall(
                            name=_tc.function.name,
                            arguments=parsed_arguments,
                        ),
                    )
                )

            # If parsing of all tool calls succeeds in the end, we ignore any issues
            if len(tool_calls) == len(assistant_message.tool_calls):
                parse_errors = []

        return tool_calls, parse_errors
    def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
        return tenacity.retry(
            retry=(
                tenacity.retry_if_exception_type(APIConnectionError)
                | tenacity.retry_if_exception(
                    lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
                )
            ),
            wait=tenacity.wait_exponential(),
            stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
            after=tenacity.after_log(self._logger, logging.DEBUG),
        )(func)

    def __repr__(self):
        return "GroqProvider()"
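
To illustrate the function-calling path and the cost bookkeeping above, a hedged sketch (prices and model info come from GROQ_CHAT_MODELS in this file; the weather function, its parameter schema, and the parser are illustrative and assume forge's `CompletionModelFunction`/`JSONSchema` shapes):

import asyncio

from forge.llm.providers.groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
from forge.llm.providers.schema import (
    ChatMessage,
    CompletionModelFunction,
    JSONSchema,
)


async def main():
    provider = GroqProvider()

    # With exactly one function, _get_chat_completion_args forces
    # tool_choice to that function.
    get_weather = CompletionModelFunction(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={"city": JSONSchema(type=JSONSchema.Type.STRING, required=True)},
    )

    response = await provider.create_chat_completion(
        model_prompt=[ChatMessage.user("What's the weather in Berlin?")],
        model_name=GroqModelName.LLAMA3_70B,
        functions=[get_weather],
        # A parser that raises triggers the self-correction loop in
        # create_chat_completion (up to fix_failed_parse_tries attempts).
        completion_parser=lambda msg: msg.tool_calls[0].function.arguments,
    )
    print(response.parsed_result)  # e.g. {'city': 'Berlin'}

    # Cost per call follows the per-token prices in GROQ_CHAT_MODELS.
    info = GROQ_CHAT_MODELS[GroqModelName.LLAMA3_70B]
    cost = (
        response.prompt_tokens_used * info.prompt_token_cost
        + response.completion_tokens_used * info.completion_token_cost
    )
    print(f"~${cost:.6f} for this call")


asyncio.run(main())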
forge/forge/llm/providers/multi.py
@@ -8,6 +8,7 @@ from pydantic import ValidationError
 from forge.models.config import Configurable
 
 from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
+from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider
 from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
 from .schema import (
     AssistantChatMessage,
@@ -25,9 +26,9 @@ from .schema import (
 
 _T = TypeVar("_T")
 
-ModelName = AnthropicModelName | OpenAIModelName
+ModelName = AnthropicModelName | GroqModelName | OpenAIModelName
 
-CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
+CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
 
 
 class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
@@ -143,16 +144,17 @@ class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
     @classmethod
     def _get_model_provider_class(
         cls, model_name: ModelName
-    ) -> type[AnthropicProvider | OpenAIProvider]:
+    ) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
         return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
 
     @classmethod
     def _get_provider_class(
         cls, provider_name: ModelProviderName
-    ) -> type[AnthropicProvider | OpenAIProvider]:
+    ) -> type[AnthropicProvider | GroqProvider | OpenAIProvider]:
         try:
             return {
                 ModelProviderName.ANTHROPIC: AnthropicProvider,
+                ModelProviderName.GROQ: GroqProvider,
                 ModelProviderName.OPENAI: OpenAIProvider,
             }[provider_name]
         except KeyError:
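
With the changes above, `MultiProvider` now routes Groq model names to the new provider. A minimal sketch of the routing (uses the private `_get_model_provider_class` purely for illustration):

from forge.llm.providers.groq import GroqModelName, GroqProvider
from forge.llm.providers.multi import CHAT_MODELS, MultiProvider

# The merged registry now contains the Groq entries alongside
# Anthropic's and OpenAI's.
assert GroqModelName.LLAMA3_8B in CHAT_MODELS

# Routing: model name -> provider name -> provider class.
provider_class = MultiProvider._get_model_provider_class(GroqModelName.LLAMA3_8B)
assert provider_class is GroqProvider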
forge/forge/llm/providers/schema.py
@@ -45,6 +45,7 @@ class ModelProviderService(str, enum.Enum):
 class ModelProviderName(str, enum.Enum):
     OPENAI = "openai"
     ANTHROPIC = "anthropic"
+    GROQ = "groq"
 
 
 class ChatMessage(BaseModel):
forge/poetry.lock (generated; 22 lines changed)
@@ -1762,6 +1762,25 @@ files = [
 docs = ["Sphinx", "furo"]
 test = ["objgraph", "psutil"]
 
+[[package]]
+name = "groq"
+version = "0.8.0"
+description = "The official Python library for the groq API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "groq-0.8.0-py3-none-any.whl", hash = "sha256:f5e4e892d45001241a930db451e633ca1f0007e3f749deaa5d7360062fcd61e3"},
+    {file = "groq-0.8.0.tar.gz", hash = "sha256:37ceb2f706bd516d0bfcac8e89048a24b375172987a0d6bd9efb521c54f6deff"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.7,<5"
+
 [[package]]
 name = "grpcio"
 version = "1.60.0"
@@ -4277,7 +4296,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -6117,4 +6135,4 @@ benchmark = ["agbenchmark"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "5e493a1132ffacec11607b5ca1c91147d5b43c47af6c1ac039a235b3a68fcb1e"
+content-hash = "19547e2de5ebeda0d3ef171b5dda23aecb6003c48f633ef09367095463206c3f"
forge/pyproject.toml
@@ -27,6 +27,7 @@ fastapi = "^0.109.1"
 gitpython = "^3.1.32"
 google-api-python-client = "*"
 google-cloud-storage = "^2.13.0"
+groq = "^0.8.0"
 jinja2 = "^3.1.2"
 jsonschema = "*"
 litellm = "^1.17.9"