diff --git a/rnd/autogpt_server/autogpt_server/blocks/llm.py b/rnd/autogpt_server/autogpt_server/blocks/llm.py index 022fd14827..a755d70a11 100644 --- a/rnd/autogpt_server/autogpt_server/blocks/llm.py +++ b/rnd/autogpt_server/autogpt_server/blocks/llm.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from typing import NamedTuple +from typing import List, NamedTuple import anthropic import ollama @@ -8,7 +8,7 @@ import openai from groq import Groq from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema -from autogpt_server.data.model import BlockSecret, SecretField +from autogpt_server.data.model import BlockSecret, SchemaField, SecretField from autogpt_server.util import json logger = logging.getLogger(__name__) @@ -409,3 +409,127 @@ class TextSummarizerBlock(Block): ).send(None)[ 1 ] # Get the first yielded value + + +class MessageRole(str, Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + + +class Message(BlockSchema): + role: MessageRole + content: str + + +class AdvancedLlmCallBlock(Block): + class Input(BlockSchema): + messages: List[Message] = SchemaField( + description="List of messages in the conversation.", min_items=1 + ) + model: LlmModel = SchemaField( + default=LlmModel.GPT4_TURBO, + description="The language model to use for the conversation.", + ) + api_key: BlockSecret = SecretField( + value="", description="API key for the chosen language model provider." + ) + max_tokens: int | None = SchemaField( + default=None, + description="The maximum number of tokens to generate in the chat completion.", + ge=1, + ) + + class Output(BlockSchema): + response: str = SchemaField( + description="The model's response to the conversation." + ) + error: str = SchemaField(description="Error message if the API call failed.") + + def __init__(self): + super().__init__( + id="c3d4e5f6-g7h8-i9j0-k1l2-m3n4o5p6q7r8", + description="Advanced LLM call that takes a list of messages and sends them to the language model.", + categories={BlockCategory.LLM}, + input_schema=AdvancedLlmCallBlock.Input, + output_schema=AdvancedLlmCallBlock.Output, + test_input={ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": LlmModel.GPT4_TURBO, + "api_key": "test_api_key", + }, + test_output=( + "response", + "The 2020 World Series was played at Globe Life Field in Arlington, Texas.", + ), + test_mock={ + "llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas." + }, + ) + + @staticmethod + def llm_call( + api_key: str, + model: LlmModel, + messages: List[dict[str, str]], + max_tokens: int | None = None, + ) -> str: + provider = model.metadata.provider + + if provider == "openai": + openai.api_key = api_key + response = openai.chat.completions.create( + model=model.value, + messages=messages, # type: ignore + max_tokens=max_tokens, + ) + return response.choices[0].message.content or "" + elif provider == "anthropic": + client = anthropic.Anthropic(api_key=api_key) + response = client.messages.create( + model=model.value, max_tokens=max_tokens or 4096, messages=messages # type: ignore + ) + return response.content[0].text if response.content else "" + elif provider == "groq": + client = Groq(api_key=api_key) + response = client.chat.completions.create( + model=model.value, + messages=messages, # type: ignore + max_tokens=max_tokens, + ) + return response.choices[0].message.content or "" + elif provider == "ollama": + response = ollama.chat( + model=model.value, messages=messages, stream=False # type: ignore + ) + return response["message"]["content"] + else: + raise ValueError(f"Unsupported LLM provider: {provider}") + + def run(self, input_data: Input) -> BlockOutput: + try: + api_key = ( + input_data.api_key.get_secret_value() + or LlmApiKeys[input_data.model.metadata.provider].get_secret_value() + ) + + messages = [message.model_dump() for message in input_data.messages] + + response = self.llm_call( + api_key=api_key, + model=input_data.model, + messages=messages, + max_tokens=input_data.max_tokens, + ) + + yield "response", response + except Exception as e: + yield "error", f"Error calling LLM: {str(e)}"