feat(Block): Add AdvancedLlmCallBlock (#7677)

* feat(Block): Add AdvancedLlmCallBlock

Adds a block for handling advanced LLM calls, enabling messages to be handled within the AutoGPT builder.

* fix linting
This commit is contained in:
Toran Bruce Richards
2024-08-03 00:17:03 +01:00
committed by GitHub
parent a21fd30fce
commit e0930ba39d

View File

@@ -1,6 +1,6 @@
import logging
from enum import Enum
from typing import NamedTuple
from typing import List, NamedTuple
import anthropic
import ollama
@@ -8,7 +8,7 @@ import openai
from groq import Groq
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from autogpt_server.data.model import BlockSecret, SecretField
from autogpt_server.data.model import BlockSecret, SchemaField, SecretField
from autogpt_server.util import json
logger = logging.getLogger(__name__)
@@ -409,3 +409,127 @@ class TextSummarizerBlock(Block):
).send(None)[
1
] # Get the first yielded value
class MessageRole(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class Message(BlockSchema):
role: MessageRole
content: str
class AdvancedLlmCallBlock(Block):
class Input(BlockSchema):
messages: List[Message] = SchemaField(
description="List of messages in the conversation.", min_items=1
)
model: LlmModel = SchemaField(
default=LlmModel.GPT4_TURBO,
description="The language model to use for the conversation.",
)
api_key: BlockSecret = SecretField(
value="", description="API key for the chosen language model provider."
)
max_tokens: int | None = SchemaField(
default=None,
description="The maximum number of tokens to generate in the chat completion.",
ge=1,
)
class Output(BlockSchema):
response: str = SchemaField(
description="The model's response to the conversation."
)
error: str = SchemaField(description="Error message if the API call failed.")
def __init__(self):
super().__init__(
id="c3d4e5f6-g7h8-i9j0-k1l2-m3n4o5p6q7r8",
description="Advanced LLM call that takes a list of messages and sends them to the language model.",
categories={BlockCategory.LLM},
input_schema=AdvancedLlmCallBlock.Input,
output_schema=AdvancedLlmCallBlock.Output,
test_input={
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{
"role": "assistant",
"content": "The Los Angeles Dodgers won the World Series in 2020.",
},
{"role": "user", "content": "Where was it played?"},
],
"model": LlmModel.GPT4_TURBO,
"api_key": "test_api_key",
},
test_output=(
"response",
"The 2020 World Series was played at Globe Life Field in Arlington, Texas.",
),
test_mock={
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
},
)
@staticmethod
def llm_call(
api_key: str,
model: LlmModel,
messages: List[dict[str, str]],
max_tokens: int | None = None,
) -> str:
provider = model.metadata.provider
if provider == "openai":
openai.api_key = api_key
response = openai.chat.completions.create(
model=model.value,
messages=messages, # type: ignore
max_tokens=max_tokens,
)
return response.choices[0].message.content or ""
elif provider == "anthropic":
client = anthropic.Anthropic(api_key=api_key)
response = client.messages.create(
model=model.value, max_tokens=max_tokens or 4096, messages=messages # type: ignore
)
return response.content[0].text if response.content else ""
elif provider == "groq":
client = Groq(api_key=api_key)
response = client.chat.completions.create(
model=model.value,
messages=messages, # type: ignore
max_tokens=max_tokens,
)
return response.choices[0].message.content or ""
elif provider == "ollama":
response = ollama.chat(
model=model.value, messages=messages, stream=False # type: ignore
)
return response["message"]["content"]
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
def run(self, input_data: Input) -> BlockOutput:
try:
api_key = (
input_data.api_key.get_secret_value()
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
)
messages = [message.model_dump() for message in input_data.messages]
response = self.llm_call(
api_key=api_key,
model=input_data.model,
messages=messages,
max_tokens=input_data.max_tokens,
)
yield "response", response
except Exception as e:
yield "error", f"Error calling LLM: {str(e)}"