mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-02 10:55:14 -05:00
Compare commits
36 Commits
test/verif
...
claude-ima
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c212f2b59 | ||
|
|
ea2910c560 | ||
|
|
27d0f03db3 | ||
|
|
1cc8981799 | ||
|
|
cebbdde75e | ||
|
|
eddcc97814 | ||
|
|
c1e8451c85 | ||
|
|
643d1a9e3f | ||
|
|
a4fc0d6206 | ||
|
|
5bb43c31c5 | ||
|
|
96ffa64971 | ||
|
|
d86a41147b | ||
|
|
d3425cae46 | ||
|
|
7682cbbe6c | ||
|
|
80ee8c61c4 | ||
|
|
cba05365e9 | ||
|
|
5aadbfe98a | ||
|
|
3e0bcbc7e4 | ||
|
|
b8749f7590 | ||
|
|
3aafa53f3b | ||
|
|
20b4a0e37f | ||
|
|
23095f466a | ||
|
|
769c75e6ac | ||
|
|
11ef0486ff | ||
|
|
d72c93c037 | ||
|
|
841500f378 | ||
|
|
b052413ab4 | ||
|
|
d31167958c | ||
|
|
a1a52b9569 | ||
|
|
50ad4a34dd | ||
|
|
81c403e103 | ||
|
|
2bfaf4d80c | ||
|
|
31e49fb55c | ||
|
|
da88da9a17 | ||
|
|
fed426ff77 | ||
|
|
33390ff7fe |
@@ -4,6 +4,16 @@ from abc import ABC
|
|||||||
from enum import Enum, EnumMeta
|
from enum import Enum, EnumMeta
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
|
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import NodeExecutionStats
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.file import MediaFile, store_media_file
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from enum import _EnumMemberT
|
||||||
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
@@ -64,9 +74,43 @@ def AICredentialsField() -> AICredentials:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(str, Enum):
|
||||||
|
OPENAI = "openai"
|
||||||
|
ANTHROPIC = "anthropic"
|
||||||
|
GROQ = "groq"
|
||||||
|
OLLAMA = "ollama"
|
||||||
|
OPEN_ROUTER = "open_router"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCreator(str, Enum):
|
||||||
|
ANTHROPIC = "anthropic"
|
||||||
|
META = "meta"
|
||||||
|
GOOGLE = "google"
|
||||||
|
OPENAI = "openai"
|
||||||
|
MISTRAL = "mistral"
|
||||||
|
COHERE = "cohere"
|
||||||
|
DEEPSEEK = "deepseek"
|
||||||
|
PERPLEXITY = "perplexity"
|
||||||
|
QWEN = "qwen"
|
||||||
|
NOUS = "nous"
|
||||||
|
AMAZON = "amazon"
|
||||||
|
MICROSOFT = "microsoft"
|
||||||
|
GRYPHE = "gryphe"
|
||||||
|
EVA = "eva"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCapabilities(NamedTuple):
|
||||||
|
supports_images: bool = False
|
||||||
|
supports_functions: bool = False
|
||||||
|
supports_vision: bool = False
|
||||||
|
is_local: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadata(NamedTuple):
|
class ModelMetadata(NamedTuple):
|
||||||
provider: str
|
provider: ModelProvider
|
||||||
|
creator: ModelCreator
|
||||||
context_window: int
|
context_window: int
|
||||||
|
capabilities: ModelCapabilities = ModelCapabilities()
|
||||||
max_output_tokens: int | None
|
max_output_tokens: int | None
|
||||||
|
|
||||||
|
|
||||||
@@ -154,68 +198,114 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
|
|
||||||
|
|
||||||
MODEL_METADATA = {
|
MODEL_METADATA = {
|
||||||
# https://platform.openai.com/docs/models
|
|
||||||
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
|
|
||||||
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
|
|
||||||
LlmModel.O1_PREVIEW: ModelMetadata(
|
LlmModel.O1_PREVIEW: ModelMetadata(
|
||||||
"openai", 128000, 32768
|
ModelProvider.OPENAI,
|
||||||
), # o1-preview-2024-09-12
|
ModelCreator.OPENAI,
|
||||||
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
|
32000,
|
||||||
LlmModel.GPT4O_MINI: ModelMetadata(
|
ModelCapabilities(supports_images=True),
|
||||||
"openai", 128000, 16384
|
),
|
||||||
), # gpt-4o-mini-2024-07-18
|
LlmModel.O1_MINI: ModelMetadata(
|
||||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, 16384), # gpt-4o-2024-08-06
|
ModelProvider.OPENAI,
|
||||||
LlmModel.GPT4_TURBO: ModelMetadata(
|
ModelCreator.OPENAI,
|
||||||
"openai", 128000, 4096
|
62000,
|
||||||
), # gpt-4-turbo-2024-04-09
|
ModelCapabilities(supports_images=True),
|
||||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
|
),
|
||||||
# https://docs.anthropic.com/en/docs/about-claude/models
|
LlmModel.GPT4O_MINI: ModelMetadata(
|
||||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
ModelProvider.OPENAI,
|
||||||
"anthropic", 200000, 8192
|
ModelCreator.OPENAI,
|
||||||
), # claude-3-5-sonnet-20241022
|
128000,
|
||||||
LlmModel.CLAUDE_3_5_HAIKU: ModelMetadata(
|
ModelCapabilities(supports_images=True),
|
||||||
"anthropic", 200000, 8192
|
),
|
||||||
), # claude-3-5-haiku-20241022
|
LlmModel.GPT4O: ModelMetadata(ModelProvider.OPENAI, ModelCreator.OPENAI, 128000),
|
||||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
LlmModel.GPT4_TURBO: ModelMetadata(
|
||||||
"anthropic", 200000, 4096
|
ModelProvider.OPENAI, ModelCreator.OPENAI, 128000
|
||||||
), # claude-3-haiku-20240307
|
),
|
||||||
# https://console.groq.com/docs/models
|
LlmModel.GPT3_5_TURBO: ModelMetadata(
|
||||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, None),
|
ModelProvider.OPENAI, ModelCreator.OPENAI, 16385
|
||||||
LlmModel.LLAMA3_3_70B: ModelMetadata("groq", 128000, 32768),
|
),
|
||||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 128000, 8192),
|
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
|
||||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, None),
|
ModelProvider.ANTHROPIC,
|
||||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, None),
|
ModelCreator.ANTHROPIC,
|
||||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, None),
|
200000,
|
||||||
LlmModel.DEEPSEEK_LLAMA_70B: ModelMetadata("groq", 128000, None),
|
ModelCapabilities(supports_images=True),
|
||||||
# https://ollama.com/library
|
),
|
||||||
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata("ollama", 8192, None),
|
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||||
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192, None),
|
ModelProvider.ANTHROPIC,
|
||||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, None),
|
ModelCreator.ANTHROPIC,
|
||||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, None),
|
200000,
|
||||||
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
|
ModelCapabilities(supports_images=True),
|
||||||
# https://openrouter.ai/models
|
),
|
||||||
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
|
LlmModel.LLAMA3_8B: ModelMetadata(ModelProvider.GROQ, ModelCreator.META, 8192),
|
||||||
LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
|
LlmModel.LLAMA3_70B: ModelMetadata(ModelProvider.GROQ, ModelCreator.META, 8192),
|
||||||
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
|
LlmModel.MIXTRAL_8X7B: ModelMetadata(
|
||||||
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
|
ModelProvider.GROQ, ModelCreator.MISTRAL, 32768
|
||||||
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata("open_router", 128000, 4096),
|
),
|
||||||
LlmModel.EVA_QWEN_2_5_32B: ModelMetadata("open_router", 16384, 4096),
|
LlmModel.GEMMA_7B: ModelMetadata(ModelProvider.GROQ, ModelCreator.GOOGLE, 8192),
|
||||||
LlmModel.DEEPSEEK_CHAT: ModelMetadata("open_router", 64000, 2048),
|
LlmModel.GEMMA2_9B: ModelMetadata(ModelProvider.GROQ, ModelCreator.GOOGLE, 8192),
|
||||||
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
|
LlmModel.LLAMA3_1_405B: ModelMetadata(ModelProvider.GROQ, ModelCreator.META, 8192),
|
||||||
"open_router", 127072, 127072
|
# Limited to 16k during preview
|
||||||
|
LlmModel.LLAMA3_1_70B: ModelMetadata(ModelProvider.GROQ, ModelCreator.META, 131072),
|
||||||
|
LlmModel.LLAMA3_1_8B: ModelMetadata(ModelProvider.GROQ, ModelCreator.META, 131072),
|
||||||
|
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
|
||||||
|
ModelProvider.OLLAMA, ModelCreator.META, 8192, ModelCapabilities(is_local=True)
|
||||||
|
),
|
||||||
|
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
|
||||||
|
ModelProvider.OLLAMA, ModelCreator.META, 8192, ModelCapabilities(is_local=True)
|
||||||
|
),
|
||||||
|
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
|
||||||
|
ModelProvider.OLLAMA, ModelCreator.META, 8192, ModelCapabilities(is_local=True)
|
||||||
|
),
|
||||||
|
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
|
||||||
|
ModelProvider.OLLAMA, ModelCreator.META, 32768, ModelCapabilities(is_local=True)
|
||||||
|
),
|
||||||
|
LlmModel.GEMINI_FLASH_1_5_8B: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.GOOGLE, 8192
|
||||||
|
),
|
||||||
|
LlmModel.GROK_BETA: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.GOOGLE, 8192
|
||||||
|
),
|
||||||
|
LlmModel.MISTRAL_NEMO: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.MISTRAL, 4000
|
||||||
|
),
|
||||||
|
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.COHERE, 4000
|
||||||
|
),
|
||||||
|
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.COHERE, 4000
|
||||||
|
),
|
||||||
|
LlmModel.EVA_QWEN_2_5_32B: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.EVA, 4000
|
||||||
|
),
|
||||||
|
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.DEEPSEEK, 8192
|
||||||
|
),
|
||||||
|
LlmModel.PERPLEXITY_LLAMA_3_1_SONAR_LARGE_128K_ONLINE: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.PERPLEXITY, 8192
|
||||||
|
),
|
||||||
|
LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.QWEN, 4000
|
||||||
),
|
),
|
||||||
LlmModel.QWEN_QWQ_32B_PREVIEW: ModelMetadata("open_router", 32768, 32768),
|
|
||||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
|
||||||
"open_router", 131000, 4096
|
ModelProvider.OPEN_ROUTER, ModelCreator.NOUS, 4000
|
||||||
),
|
),
|
||||||
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
|
||||||
"open_router", 12288, 12288
|
ModelProvider.OPEN_ROUTER, ModelCreator.NOUS, 4000
|
||||||
|
),
|
||||||
|
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.AMAZON, 4000
|
||||||
|
),
|
||||||
|
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.AMAZON, 4000
|
||||||
|
),
|
||||||
|
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.AMAZON, 4000
|
||||||
|
),
|
||||||
|
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.MICROSOFT, 4000
|
||||||
|
),
|
||||||
|
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
|
||||||
|
ModelProvider.OPEN_ROUTER, ModelCreator.GRYPHE, 4000
|
||||||
),
|
),
|
||||||
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata("open_router", 300000, 5120),
|
|
||||||
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata("open_router", 128000, 5120),
|
|
||||||
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata("open_router", 300000, 5120),
|
|
||||||
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata("open_router", 65536, 4096),
|
|
||||||
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata("open_router", 4096, 4096),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for model in LlmModel:
|
for model in LlmModel:
|
||||||
@@ -518,6 +608,11 @@ def llm_call(
|
|||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageWithMedia(Message):
|
||||||
|
role: MessageRole
|
||||||
|
content: str | MediaFile
|
||||||
|
|
||||||
|
|
||||||
class AIBlockBase(Block, ABC):
|
class AIBlockBase(Block, ABC):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -540,7 +635,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
|||||||
)
|
)
|
||||||
model: LlmModel = SchemaField(
|
model: LlmModel = SchemaField(
|
||||||
title="LLM Model",
|
title="LLM Model",
|
||||||
default=LlmModel.GPT4O,
|
default=LlmModel.CLAUDE_3_5_SONNET,
|
||||||
description="The language model to use for answering the prompt.",
|
description="The language model to use for answering the prompt.",
|
||||||
advanced=False,
|
advanced=False,
|
||||||
)
|
)
|
||||||
@@ -1367,3 +1462,335 @@ class AIListGeneratorBlock(AIBlockBase):
|
|||||||
logger.debug(f"Retry prompt: {prompt}")
|
logger.debug(f"Retry prompt: {prompt}")
|
||||||
|
|
||||||
logger.debug("AIListGeneratorBlock.run completed")
|
logger.debug("AIListGeneratorBlock.run completed")
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeWithImageBlock(Block):
|
||||||
|
"""Block for calling Claude API with support for images"""
|
||||||
|
|
||||||
|
class Input(BlockSchema):
|
||||||
|
|
||||||
|
prompt: str = SchemaField(
|
||||||
|
description="The prompt to send to the language model.",
|
||||||
|
placeholder="Enter your prompt here...",
|
||||||
|
)
|
||||||
|
expected_format: dict[str, str] = SchemaField(
|
||||||
|
description="Expected format of the response. If provided, the response will be validated against this format. "
|
||||||
|
"The keys should be the expected fields in the response, and the values should be the description of the field.",
|
||||||
|
)
|
||||||
|
model: LlmModel = SchemaField(
|
||||||
|
title="LLM Model",
|
||||||
|
default=LlmModel.CLAUDE_3_5_SONNET,
|
||||||
|
description="The language model to use for the conversation.",
|
||||||
|
)
|
||||||
|
credentials: AICredentials = AICredentialsField()
|
||||||
|
sys_prompt: str = SchemaField(
|
||||||
|
title="System Prompt",
|
||||||
|
default="",
|
||||||
|
description="The system prompt to provide additional context to the model.",
|
||||||
|
)
|
||||||
|
conversation_history: list[MessageWithMedia] = SchemaField(
|
||||||
|
default=[],
|
||||||
|
description="The conversation history to provide context for the prompt.",
|
||||||
|
)
|
||||||
|
retry: int = SchemaField(
|
||||||
|
title="Retry Count",
|
||||||
|
default=3,
|
||||||
|
description="Number of times to retry the LLM call if the response does not match the expected format.",
|
||||||
|
)
|
||||||
|
prompt_values: dict[str, str | MediaFile] = SchemaField(
|
||||||
|
advanced=False,
|
||||||
|
default={},
|
||||||
|
description="Values used to fill in the prompt. Images can be provided as base64 encoded data with MIME type.",
|
||||||
|
)
|
||||||
|
max_tokens: int | None = SchemaField(
|
||||||
|
advanced=True,
|
||||||
|
default=None,
|
||||||
|
description="The maximum number of tokens to generate in the chat completion.",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
response: dict[str, Any] = SchemaField(
|
||||||
|
description="The response object generated by the language model."
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if the API call failed.")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="bc043b3e-2926-4ed7-b276-735535d1a945",
|
||||||
|
description="Call Claude with support for images to generate formatted object based on the given prompt.",
|
||||||
|
categories={BlockCategory.AI},
|
||||||
|
input_schema=ClaudeWithImageBlock.Input,
|
||||||
|
output_schema=ClaudeWithImageBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"model": LlmModel.CLAUDE_3_5_SONNET,
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
"expected_format": {
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
"prompt": "Describe this image",
|
||||||
|
"prompt_values": {
|
||||||
|
"image": {
|
||||||
|
"data": "",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
|
||||||
|
test_mock={
|
||||||
|
"llm_call": lambda *args, **kwargs: (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"key1": "key1Value",
|
||||||
|
"key2": "key2Value",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def llm_call(
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
llm_model: LlmModel,
|
||||||
|
prompt: list[dict],
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> tuple[str, int, int]:
|
||||||
|
"""
|
||||||
|
Call the Claude API with support for images in the messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
credentials: API credentials for Claude
|
||||||
|
llm_model: The LLM model to use (must be Claude)
|
||||||
|
prompt: List of message dictionaries that can include image content
|
||||||
|
max_tokens: Maximum tokens to generate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple containing:
|
||||||
|
- The text response
|
||||||
|
- Number of input tokens used
|
||||||
|
- Number of output tokens used
|
||||||
|
"""
|
||||||
|
if llm_model.metadata.provider != "anthropic":
|
||||||
|
raise ValueError("Only Claude models are supported for image processing")
|
||||||
|
|
||||||
|
# Extract system prompt if present
|
||||||
|
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||||
|
sysprompt = " ".join(system_messages)
|
||||||
|
|
||||||
|
# Build messages array with content that can include images
|
||||||
|
messages = []
|
||||||
|
last_role = None
|
||||||
|
|
||||||
|
for p in prompt:
|
||||||
|
if p["role"] in ["user", "assistant"]:
|
||||||
|
message_content = []
|
||||||
|
|
||||||
|
# Handle text content
|
||||||
|
if isinstance(p["content"], str):
|
||||||
|
message_content.append({"type": "text", "text": p["content"]})
|
||||||
|
# Handle mixed content array with images
|
||||||
|
elif isinstance(p["content"], list):
|
||||||
|
message_content.extend(p["content"])
|
||||||
|
|
||||||
|
if p["role"] != last_role:
|
||||||
|
messages.append({"role": p["role"], "content": message_content})
|
||||||
|
last_role = p["role"]
|
||||||
|
else:
|
||||||
|
# Combine with previous message if same role
|
||||||
|
messages[-1]["content"].extend(message_content)
|
||||||
|
|
||||||
|
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = client.messages.create(
|
||||||
|
model=llm_model.value,
|
||||||
|
system=sysprompt,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=max_tokens or 8192,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not resp.content:
|
||||||
|
raise ValueError("No content returned from Anthropic.")
|
||||||
|
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
resp.content[0].name
|
||||||
|
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||||
|
else resp.content[0].text
|
||||||
|
),
|
||||||
|
resp.usage.input_tokens,
|
||||||
|
resp.usage.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
except anthropic.APIError as e:
|
||||||
|
error_message = f"Anthropic API error: {str(e)}"
|
||||||
|
logger.error(error_message)
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
graph_exec_id: str,
|
||||||
|
credentials: APIKeyCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
logger.debug(f"Calling Claude with input data: {input_data}")
|
||||||
|
|
||||||
|
# Start with any existing conversation history
|
||||||
|
prompt = [p.model_dump() for p in input_data.conversation_history]
|
||||||
|
|
||||||
|
def trim_prompt(s: str) -> str:
|
||||||
|
lines = s.strip().split("\n")
|
||||||
|
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||||
|
|
||||||
|
# Handle prompt values including images
|
||||||
|
content = []
|
||||||
|
values: dict[str, str | MediaFile] = input_data.prompt_values
|
||||||
|
|
||||||
|
# Add any images from prompt_values
|
||||||
|
for key, value in values.items():
|
||||||
|
# This is an image
|
||||||
|
if isinstance(value, MediaFile):
|
||||||
|
# media file is a base64 encoded image
|
||||||
|
# read the media file
|
||||||
|
media_path = store_media_file(
|
||||||
|
graph_exec_id=graph_exec_id, file=value, return_content=True
|
||||||
|
)
|
||||||
|
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_path.split(";")[0].split(":")[1],
|
||||||
|
"data": media_path,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the text prompt
|
||||||
|
if input_data.prompt:
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": fmt.format_string(
|
||||||
|
input_data.prompt,
|
||||||
|
{k: v for k, v in values.items() if isinstance(v, str)},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add system prompt if provided
|
||||||
|
if input_data.sys_prompt:
|
||||||
|
prompt.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": fmt.format_string(input_data.sys_prompt, values),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add expected format if provided
|
||||||
|
if input_data.expected_format:
|
||||||
|
expected_format = [
|
||||||
|
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||||
|
]
|
||||||
|
format_prompt = ",\n ".join(expected_format)
|
||||||
|
sys_prompt = trim_prompt(
|
||||||
|
f"""
|
||||||
|
|Reply strictly only in the following JSON format:
|
||||||
|
|{{
|
||||||
|
| {format_prompt}
|
||||||
|
|}}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
prompt.append({"role": "system", "content": sys_prompt})
|
||||||
|
|
||||||
|
# Add the main prompt with images and text
|
||||||
|
prompt.append({"role": "user", "content": content})
|
||||||
|
|
||||||
|
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(resp)
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
return {}, f"Expected a dictionary, but got {type(parsed)}"
|
||||||
|
if input_data.expected_format:
|
||||||
|
miss_keys = set(input_data.expected_format.keys()) - set(
|
||||||
|
parsed.keys()
|
||||||
|
)
|
||||||
|
if miss_keys:
|
||||||
|
return parsed, f"Missing keys: {miss_keys}"
|
||||||
|
return parsed, None
|
||||||
|
except JSONDecodeError as e:
|
||||||
|
return {}, f"JSON decode error: {e}"
|
||||||
|
|
||||||
|
logger.info(f"Claude request: {prompt}")
|
||||||
|
retry_prompt = ""
|
||||||
|
llm_model = input_data.model
|
||||||
|
|
||||||
|
for retry_count in range(input_data.retry):
|
||||||
|
try:
|
||||||
|
response_text, input_token, output_token = self.llm_call(
|
||||||
|
credentials=credentials,
|
||||||
|
llm_model=llm_model,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=input_data.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.merge_stats(
|
||||||
|
{
|
||||||
|
"input_token_count": input_token,
|
||||||
|
"output_token_count": output_token,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Claude attempt-{retry_count} response: {response_text}")
|
||||||
|
|
||||||
|
if input_data.expected_format:
|
||||||
|
parsed_dict, parsed_error = parse_response(response_text)
|
||||||
|
if not parsed_error:
|
||||||
|
yield "response", {
|
||||||
|
k: (
|
||||||
|
json.loads(v)
|
||||||
|
if isinstance(v, str)
|
||||||
|
and v.startswith("[")
|
||||||
|
and v.endswith("]")
|
||||||
|
else (", ".join(v) if isinstance(v, list) else v)
|
||||||
|
)
|
||||||
|
for k, v in parsed_dict.items()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
yield "response", {"response": response_text}
|
||||||
|
return
|
||||||
|
|
||||||
|
retry_prompt = trim_prompt(
|
||||||
|
f"""
|
||||||
|
|This is your previous error response:
|
||||||
|
|--
|
||||||
|
|{response_text}
|
||||||
|
|--
|
||||||
|
|
|
||||||
|
|And this is the error:
|
||||||
|
|--
|
||||||
|
|{parsed_error}
|
||||||
|
|--
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
prompt.append({"role": "user", "content": retry_prompt})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error calling Claude: {e}")
|
||||||
|
retry_prompt = f"Error calling Claude: {e}"
|
||||||
|
finally:
|
||||||
|
self.merge_stats(
|
||||||
|
{
|
||||||
|
"llm_call_count": retry_count + 1,
|
||||||
|
"llm_retry_count": retry_count,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(retry_prompt)
|
||||||
|
|||||||
Reference in New Issue
Block a user