feat(blocks): Add support for additional LLM providers to LLM Block (#7434)

This commit adds support for the following models:

```python
# OpenAI Models
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"

# Anthropic models
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"

# Groq models
LLAMA3_8B = "llama3-8b-8192"
LLAMA3_70B = "llama3-70b-8192"
MIXTRAL_8X7B = "mixtral-8x7b-32768"
GEMMA_7B = "gemma-7b-it"
GEMMA2_9B = "gemma2-9b-it"
```

Every model has been tested with a single LLM block and is confirmed to be working in that setup.
This commit is contained in:
Toran Bruce Richards
2024-07-15 14:30:56 +01:00
committed by GitHub
parent 450f120510
commit 93b6e0ee51
5 changed files with 132 additions and 51 deletions

View File

@@ -1,25 +1,69 @@
import logging
from enum import Enum
from typing import NamedTuple
import openai
import anthropic
from groq import Groq
from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
from autogpt_server.util import json
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fallback API-key secrets, one per supported LLM provider. run() falls back
# to these when the block input carries no explicit api_key value.
LlmApiKeys = {
    provider: BlockFieldSecret(f"{provider}_api_key")
    for provider in ("openai", "anthropic", "groq")
}
# Static metadata for a supported LLM: which provider serves it and how many
# tokens fit in its context window.
ModelMetadata = NamedTuple(
    "ModelMetadata",
    [("provider", str), ("context_window", int)],
)
class LlmModel(str, Enum):
    """Identifiers of the LLM models supported by the LLM block.

    Member values are the provider-facing model-name strings; per-model
    metadata (provider, context window) lives in MODEL_METADATA.
    """

    # OpenAI models
    GPT4O = "gpt-4o"
    GPT4_TURBO = "gpt-4-turbo"
    GPT3_5_TURBO = "gpt-3.5-turbo"
    # Anthropic models
    CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    # Groq models
    LLAMA3_8B = "llama3-8b-8192"
    LLAMA3_70B = "llama3-70b-8192"
    MIXTRAL_8X7B = "mixtral-8x7b-32768"
    GEMMA_7B = "gemma-7b-it"
    GEMMA2_9B = "gemma2-9b-it"

    # Deprecated pre-rename spelling, kept so existing references to
    # LlmModel.openai_gpt4 keep working. Declared AFTER GPT4_TURBO on
    # purpose: Enum makes the later equal-valued member an alias, so the
    # uppercase name stays the canonical member (previously the order was
    # reversed and GPT4_TURBO was silently aliased to openai_gpt4).
    openai_gpt4 = "gpt-4-turbo"

    @property
    def metadata(self) -> "ModelMetadata":
        # Quoted annotation: ModelMetadata lives in this module; quoting
        # avoids evaluating the name while the class body is being built.
        return MODEL_METADATA[self]
# (model, provider, context-window-in-tokens) for every supported model,
# grouped by provider: OpenAI, then Anthropic, then Groq.
_MODEL_SPECS = [
    (LlmModel.GPT4O, "openai", 128000),
    (LlmModel.GPT4_TURBO, "openai", 128000),
    (LlmModel.GPT3_5_TURBO, "openai", 16385),
    (LlmModel.CLAUDE_3_5_SONNET, "anthropic", 200000),
    (LlmModel.CLAUDE_3_HAIKU, "anthropic", 200000),
    (LlmModel.LLAMA3_8B, "groq", 8192),
    (LlmModel.LLAMA3_70B, "groq", 8192),
    (LlmModel.MIXTRAL_8X7B, "groq", 32768),
    (LlmModel.GEMMA_7B, "groq", 8192),
    (LlmModel.GEMMA2_9B, "groq", 8192),
]

# Maps each LlmModel member to its ModelMetadata (used by LlmModel.metadata).
MODEL_METADATA = {
    model: ModelMetadata(provider, window)
    for model, provider, window in _MODEL_SPECS
}
class LlmCallBlock(Block):
class Input(BlockSchema):
    """Input schema for the LLM call block.

    Fixes diff-residue duplication: `api_key` and `model` were each declared
    twice (stale pre-change line plus current line); with plain class
    attributes the later assignment silently wins, which left `model`
    defaulting to the removed `openai_gpt4` spelling. Each field is now
    declared exactly once with the current default.
    """

    # The user prompt to send to the model.
    prompt: str
    # Model to call; defaults to GPT-4 Turbo.
    model: LlmModel = LlmModel.GPT4_TURBO
    # Per-call API key override; empty by default so run() falls back to
    # the provider-level secret in LlmApiKeys.
    api_key: BlockFieldSecret = BlockFieldSecret(value="")
    # Optional system prompt prepended to the conversation.
    sys_prompt: str = ""
    # Expected JSON response keys -> description; empty means free-form text.
    expected_format: dict[str, str] = {}
    # Number of attempts when the response fails to parse.
    retry: int = 3
class Output(BlockSchema):
@@ -32,7 +76,7 @@ class LlmCallBlock(Block):
input_schema=LlmCallBlock.Input,
output_schema=LlmCallBlock.Output,
test_input={
"model": "gpt-4-turbo",
"model": LlmModel.GPT4_TURBO,
"api_key": "fake-api",
"expected_format": {
"key1": "value1",
@@ -48,14 +92,38 @@ class LlmCallBlock(Block):
)
@staticmethod
def llm_call(
    api_key: str, model: LlmModel, prompt: list[dict], json_format: bool
) -> str:
    """Dispatch a chat-completion request to the model's provider.

    Merges the diff residue in which two complete `def llm_call` bodies
    appeared back-to-back: the single @staticmethod decorator attached to
    the first (old) def while the surviving second def was undecorated.
    Only the current multi-provider implementation is kept, decorated.

    Args:
        api_key: Provider API key to authenticate with.
        model: Model to call; its metadata selects the provider branch.
        prompt: Chat messages as [{"role": ..., "content": ...}, ...].
        json_format: If True, request a JSON-object response where the
            provider supports it (OpenAI and Groq; Anthropic has no such
            switch in this code path).

    Returns:
        The model's text reply, or "" when the provider returns no content.

    Raises:
        ValueError: If the model's provider is not one of openai /
            anthropic / groq.
    """
    provider = model.metadata.provider
    if provider == "openai":
        openai.api_key = api_key
        response = openai.chat.completions.create(
            model=model.value,
            messages=prompt,  # type: ignore
            response_format={"type": "json_object"} if json_format else None,  # type: ignore
        )
        return response.choices[0].message.content or ""
    elif provider == "anthropic":
        # Anthropic takes the system prompt as a separate argument, so the
        # message list is split into system text and user messages.
        sysprompt = "".join(p["content"] for p in prompt if p["role"] == "system")
        usrprompt = [p for p in prompt if p["role"] == "user"]
        client = anthropic.Anthropic(api_key=api_key)
        response = client.messages.create(
            model=model.value,
            max_tokens=4096,
            system=sysprompt,
            messages=usrprompt,  # type: ignore
        )
        return response.content[0].text if response.content else ""
    elif provider == "groq":
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model=model.value,
            messages=prompt,  # type: ignore
            response_format={"type": "json_object"} if json_format else None,
        )
        return response.choices[0].message.content or ""
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")
def run(self, input_data: Input) -> BlockOutput:
# NOTE(review): the span below interleaves the pre-change and post-change
# versions of this method (diff residue — see the `@@` hunk markers, which
# also hide some context lines). Comments refer to the post-change lines;
# do not treat this span as a coherent single implementation.
prompt = []
@@ -68,17 +136,15 @@ class LlmCallBlock(Block):
prompt.append({"role": "system", "content": input_data.sys_prompt})
# When an expected_format dict is given, a JSON-format instruction listing
# the expected keys is appended as an extra system message.
if input_data.expected_format:
expected_format = [f'"{k}": "{v}"' for k, v in
input_data.expected_format.items()]
expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
format_prompt = ",\n  ".join(expected_format)
sys_prompt = f"""
sys_prompt = trim_prompt(f"""
|Reply in json format:
|{{
|  {format_prompt}
|}}
"""
prompt.append({"role": "system", "content": trim_prompt(sys_prompt)})
""")
prompt.append({"role": "system", "content": sys_prompt})
prompt.append({"role": "user", "content": input_data.prompt})
@@ -94,35 +160,42 @@ class LlmCallBlock(Block):
logger.warning(f"LLM request: {prompt}")
retry_prompt = ""
model = input_data.model
# Per-call key wins; otherwise fall back to the provider-level secret.
api_key = input_data.api_key.get() or LlmApiKeys[model.metadata.provider].get()
for retry_count in range(input_data.retry):
response_text = self.llm_call(
api_key=input_data.api_key.get(),
model=input_data.model,
prompt=prompt,
json=bool(input_data.expected_format),
)
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
# Post-change version: the call / parse / yield cycle is wrapped in try so
# provider errors feed the retry loop instead of escaping.
try:
response_text = self.llm_call(
api_key=api_key,
model=model,
prompt=prompt,
json_format=bool(input_data.expected_format),
)
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
parsed_dict, parsed_error = parse_response(response_text)
if not parsed_error:
yield "response", {k: str(v) for k, v in parsed_dict.items()}
if input_data.expected_format:
parsed_dict, parsed_error = parse_response(response_text)
if not parsed_error:
# All parsed values are stringified before being yielded.
yield "response", {k: str(v) for k, v in parsed_dict.items()}
return
else:
yield "response", {"response": response_text}
return
else:
yield "response", {"response": response_text}
return
retry_prompt = f"""
|This is your previous error response:
|--
|{response_text}
|--
|
|And this is the error:
|--
|{parsed_error}
|--
"""
prompt.append({"role": "user", "content": trim_prompt(retry_prompt)})
# Parse failure: feed the bad response and the parse error back to the
# model as a user message before the next attempt.
retry_prompt = trim_prompt(f"""
|This is your previous error response:
|--
|{response_text}
|--
|
|And this is the error:
|--
|{parsed_error}
|--
""")
prompt.append({"role": "user", "content": retry_prompt})
except Exception as e:
logger.error(f"Error calling LLM: {e}")
retry_prompt = f"Error calling LLM: {e}"
yield "error", retry_prompt

View File

@@ -16,8 +16,12 @@ BlockOutput = Generator[BlockData, None, None]
class BlockFieldSecret:
def __init__(self, key=None, value=None):
    """Initialize the secret field.

    Merges the diff residue in which the old and new `__init__` appeared
    back-to-back; only the current signature (key first) is kept. Note the
    behavioral difference from the old version: any non-None `value` —
    including "" — is used as-is, instead of being coerced to a secret
    lookup by the old `value or ...` truthiness test.

    Args:
        key: Name of the secret to look up when no explicit value is given.
        value: Explicit secret value; skips the secret lookup when not None.

    Raises:
        ValueError: If no value was given and the secret `key` is not found.
    """
    if value is not None:
        self._value = value
        return
    self._value = self.__get_secret(key)
    if self._value is None:
        raise ValueError(f"Secret {key} not found.")

View File

@@ -69,7 +69,9 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
"""Secrets for the server."""
openai_api_key: str = Field(default="no_key", description="OpenAI API key")
openai_api_key: str = Field(default="", description="OpenAI API key")
anthropic_api_key: str = Field(default="", description="Anthropic API key")
groq_api_key: str = Field(default="", description="Groq API key")
reddit_client_id: str = Field(default="", description="Reddit client ID")
reddit_client_secret: str = Field(default="", description="Reddit client secret")

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "agpt"
@@ -25,7 +25,7 @@ requests = "*"
sentry-sdk = "^1.40.4"
[package.extras]
benchmark = ["agbenchmark @ file:///Users/czerwinski/Projects/AutoGPT/benchmark"]
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
[package.source]
type = "directory"
@@ -329,7 +329,7 @@ watchdog = "4.0.0"
webdriver-manager = "^4.0.1"
[package.extras]
benchmark = ["agbenchmark @ file:///Users/czerwinski/Projects/AutoGPT/benchmark"]
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
[package.source]
type = "directory"
@@ -6218,4 +6218,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "672bed951372c3cc6ec295dbde92e417a839adc1e9b5ac7b377df7c9f828a382"
content-hash = "ac40cb89830fcc95bec5c0dfb0646503dd9bd3abc26b7258ed69403fd546bed5"

View File

@@ -32,6 +32,8 @@ pydantic-settings = "^2.3.4"
praw = "^7.7.1"
openai = "^1.35.7"
jsonref = "^1.1.0"
groq = "^0.8.0"
anthropic = "^0.25.1"
[tool.poetry.group.dev.dependencies]