mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(blocks): Add support for additional LLM providers to LLM Block (#7434)
This commit adds support for the following models:

```python
# OpenAI Models
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# Groq models
LLAMA3_8B = "llama3-8b-8192"
LLAMA3_70B = "llama3-70b-8192"
MIXTRAL_8X7B = "mixtral-8x7b-32768"
GEMMA_7B = "gemma-7b-it"
GEMMA2_9B = "gemma2-9b-it"
```

Every model has been tested with a single LLM block and is confirmed to be working in that setup.
This commit is contained in:
committed by
GitHub
parent
450f120510
commit
93b6e0ee51
@@ -1,25 +1,69 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import NamedTuple
|
||||
|
||||
import openai
|
||||
import anthropic
|
||||
from groq import Groq
|
||||
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
|
||||
from autogpt_server.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LlmApiKeys = {
|
||||
"openai": BlockFieldSecret("openai_api_key"),
|
||||
"anthropic": BlockFieldSecret("anthropic_api_key"),
|
||||
"groq": BlockFieldSecret("groq_api_key"),
|
||||
}
|
||||
|
||||
|
||||
class ModelMetadata(NamedTuple):
    """Static description of one LLM model: who serves it and how much context it takes."""

    # Provider identifier; llm_call compares it against "openai", "anthropic" and "groq".
    provider: str
    # Size of the model's context window (presumably in tokens -- confirm).
    context_window: int
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
|
||||
openai_gpt4 = "gpt-4-turbo"
|
||||
# OpenAI models
|
||||
GPT4O = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT3_5_TURBO = "gpt-3.5-turbo"
|
||||
# Anthropic models
|
||||
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# Groq models
|
||||
LLAMA3_8B = "llama3-8b-8192"
|
||||
LLAMA3_70B = "llama3-70b-8192"
|
||||
MIXTRAL_8X7B = "mixtral-8x7b-32768"
|
||||
GEMMA_7B = "gemma-7b-it"
|
||||
GEMMA2_9B = "gemma2-9b-it"
|
||||
|
||||
@property
|
||||
def metadata(self) -> ModelMetadata:
|
||||
return MODEL_METADATA[self]
|
||||
|
||||
|
||||
# Provider and context-window size for every LlmModel member.
# LlmModel.metadata indexes this mapping with the enum member itself, so each
# member of LlmModel needs an entry here.
MODEL_METADATA = {
    # OpenAI models
    LlmModel.GPT4O: ModelMetadata(provider="openai", context_window=128000),
    LlmModel.GPT4_TURBO: ModelMetadata(provider="openai", context_window=128000),
    LlmModel.GPT3_5_TURBO: ModelMetadata(provider="openai", context_window=16385),
    # Anthropic models
    LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(provider="anthropic", context_window=200000),
    LlmModel.CLAUDE_3_HAIKU: ModelMetadata(provider="anthropic", context_window=200000),
    # Groq models
    LlmModel.LLAMA3_8B: ModelMetadata(provider="groq", context_window=8192),
    LlmModel.LLAMA3_70B: ModelMetadata(provider="groq", context_window=8192),
    LlmModel.MIXTRAL_8X7B: ModelMetadata(provider="groq", context_window=32768),
    LlmModel.GEMMA_7B: ModelMetadata(provider="groq", context_window=8192),
    LlmModel.GEMMA2_9B: ModelMetadata(provider="groq", context_window=8192),
}
|
||||
|
||||
|
||||
class LlmCallBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
prompt: str
|
||||
api_key: BlockFieldSecret = BlockFieldSecret(key="openai_api_key")
|
||||
model: LlmModel = LlmModel.GPT4_TURBO
|
||||
api_key: BlockFieldSecret = BlockFieldSecret(value="")
|
||||
sys_prompt: str = ""
|
||||
expected_format: dict[str, str] = {}
|
||||
model: LlmModel = LlmModel.openai_gpt4
|
||||
retry: int = 3
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -32,7 +76,7 @@ class LlmCallBlock(Block):
|
||||
input_schema=LlmCallBlock.Input,
|
||||
output_schema=LlmCallBlock.Output,
|
||||
test_input={
|
||||
"model": "gpt-4-turbo",
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"api_key": "fake-api",
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -48,14 +92,38 @@ class LlmCallBlock(Block):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json: bool) -> str:
|
||||
openai.api_key = api_key
|
||||
response = openai.chat.completions.create(
|
||||
model=model,
|
||||
messages=prompt, # type: ignore
|
||||
response_format={"type": "json_object"} if json else None,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str:
    """Send a chat prompt to the provider that serves `model` and return the reply text.

    Args:
        api_key: API key for the provider named by `model.metadata.provider`.
        model: Model to query; `model.value` is sent as the provider's model name.
        prompt: Chat messages, each a dict with "role" and "content" keys.
        json_format: When true, the OpenAI and Groq calls ask for a JSON-object
            response via `response_format`. The Anthropic call passes no such
            argument, so there the JSON instruction has to come from the
            system prompt.

    Returns:
        The text of the first choice (OpenAI, Groq) or first content block
        (Anthropic); "" when the provider returned no text.

    Raises:
        ValueError: If the model's provider is not "openai", "anthropic" or "groq".
    """
    provider = model.metadata.provider

    if provider == "openai":
        # Build a per-call client rather than assigning the module-level
        # `openai.api_key`: the global is shared by every caller in the
        # process, so calls using different keys could overwrite each other.
        # This also matches how the Anthropic and Groq branches work.
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=model.value,
            messages=prompt,  # type: ignore
            response_format={"type": "json_object"} if json_format else None,  # type: ignore
        )
        return response.choices[0].message.content or ""

    if provider == "anthropic":
        # Anthropic takes the system prompt as a separate argument instead of
        # as messages. Join the system messages with a newline so that separate
        # system messages are not glued together mid-sentence.
        sysprompt = "\n".join(p["content"] for p in prompt if p["role"] == "system")
        usrprompt = [p for p in prompt if p["role"] == "user"]
        client = anthropic.Anthropic(api_key=api_key)
        response = client.messages.create(
            model=model.value,
            max_tokens=4096,
            system=sysprompt,
            messages=usrprompt,  # type: ignore
        )
        return response.content[0].text if response.content else ""

    if provider == "groq":
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model=model.value,
            messages=prompt,  # type: ignore
            response_format={"type": "json_object"} if json_format else None,
        )
        return response.choices[0].message.content or ""

    raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
prompt = []
|
||||
@@ -68,17 +136,15 @@ class LlmCallBlock(Block):
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
if input_data.expected_format:
|
||||
expected_format = [f'"{k}": "{v}"' for k, v in
|
||||
input_data.expected_format.items()]
|
||||
|
||||
expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
sys_prompt = f"""
|
||||
sys_prompt = trim_prompt(f"""
|
||||
|Reply in json format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
"""
|
||||
prompt.append({"role": "system", "content": trim_prompt(sys_prompt)})
|
||||
""")
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
@@ -94,35 +160,42 @@ class LlmCallBlock(Block):
|
||||
|
||||
logger.warning(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
model = input_data.model
|
||||
api_key = input_data.api_key.get() or LlmApiKeys[model.metadata.provider].get()
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
response_text = self.llm_call(
|
||||
api_key=input_data.api_key.get(),
|
||||
model=input_data.model,
|
||||
prompt=prompt,
|
||||
json=bool(input_data.expected_format),
|
||||
)
|
||||
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
try:
|
||||
response_text = self.llm_call(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
json_format=bool(input_data.expected_format),
|
||||
)
|
||||
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
||||
if input_data.expected_format:
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
||||
return
|
||||
else:
|
||||
yield "response", {"response": response_text}
|
||||
return
|
||||
else:
|
||||
yield "response", {"response": response_text}
|
||||
return
|
||||
|
||||
retry_prompt = f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{parsed_error}
|
||||
|--
|
||||
"""
|
||||
prompt.append({"role": "user", "content": trim_prompt(retry_prompt)})
|
||||
retry_prompt = trim_prompt(f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{parsed_error}
|
||||
|--
|
||||
""")
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling LLM: {e}")
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
|
||||
yield "error", retry_prompt
|
||||
|
||||
@@ -16,8 +16,12 @@ BlockOutput = Generator[BlockData, None, None]
|
||||
|
||||
|
||||
class BlockFieldSecret:
|
||||
def __init__(self, value=None, key=None):
|
||||
self._value = value or self.__get_secret(key)
|
||||
def __init__(self, key=None, value=None):
|
||||
if value is not None:
|
||||
self._value = value
|
||||
return
|
||||
|
||||
self._value = self.__get_secret(key)
|
||||
if self._value is None:
|
||||
raise ValueError(f"Secret {key} not found.")
|
||||
|
||||
|
||||
@@ -69,7 +69,9 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
|
||||
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
"""Secrets for the server."""
|
||||
openai_api_key: str = Field(default="no_key", description="OpenAI API key")
|
||||
openai_api_key: str = Field(default="", description="OpenAI API key")
|
||||
anthropic_api_key: str = Field(default="", description="Anthropic API key")
|
||||
groq_api_key: str = Field(default="", description="Groq API key")
|
||||
|
||||
reddit_client_id: str = Field(default="", description="Reddit client ID")
|
||||
reddit_client_secret: str = Field(default="", description="Reddit client secret")
|
||||
|
||||
8
rnd/autogpt_server/poetry.lock
generated
8
rnd/autogpt_server/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "agpt"
|
||||
@@ -25,7 +25,7 @@ requests = "*"
|
||||
sentry-sdk = "^1.40.4"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark @ file:///Users/czerwinski/Projects/AutoGPT/benchmark"]
|
||||
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
@@ -329,7 +329,7 @@ watchdog = "4.0.0"
|
||||
webdriver-manager = "^4.0.1"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark @ file:///Users/czerwinski/Projects/AutoGPT/benchmark"]
|
||||
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
@@ -6218,4 +6218,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "672bed951372c3cc6ec295dbde92e417a839adc1e9b5ac7b377df7c9f828a382"
|
||||
content-hash = "ac40cb89830fcc95bec5c0dfb0646503dd9bd3abc26b7258ed69403fd546bed5"
|
||||
|
||||
@@ -32,6 +32,8 @@ pydantic-settings = "^2.3.4"
|
||||
praw = "^7.7.1"
|
||||
openai = "^1.35.7"
|
||||
jsonref = "^1.1.0"
|
||||
groq = "^0.8.0"
|
||||
anthropic = "^0.25.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
Reference in New Issue
Block a user