feat: image block for claude

This commit is contained in:
Nicholas Tindle
2025-01-21 17:03:11 +01:00
parent 23095f466a
commit 20b4a0e37f

View File

@@ -4,7 +4,7 @@ from abc import ABC
from enum import Enum, EnumMeta
from json import JSONDecodeError
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple, TypedDict
from pydantic import SecretStr
@@ -1173,3 +1173,328 @@ class AIListGeneratorBlock(AIBlockBase):
logger.debug(f"Retry prompt: {prompt}")
logger.debug("AIListGeneratorBlock.run completed")
class ClaudeWithImageBlock(Block):
    """Call the Anthropic Claude API with support for image inputs.

    Images are supplied through ``prompt_values`` as dicts carrying a MIME
    ``content_type`` and base64-encoded ``data``; they are forwarded to Claude
    as ``image`` content parts alongside the text prompt. When
    ``expected_format`` is provided, the response is parsed as JSON and
    validated against the expected keys, retrying up to ``retry`` times with
    the parse error fed back to the model.
    """

    class Input(BlockSchema):
        class Image(TypedDict):
            # Shape of an image value inside `prompt_values`.
            content_type: str  # MIME type of the image (e.g. "image/jpeg")
            data: str  # Base64 encoded image data

        prompt: str = SchemaField(
            description="The prompt to send to the language model.",
            placeholder="Enter your prompt here...",
        )
        expected_format: dict[str, str] = SchemaField(
            description="Expected format of the response. If provided, the response will be validated against this format. "
            "The keys should be the expected fields in the response, and the values should be the description of the field.",
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.CLAUDE_3_5_SONNET,
            description="The language model to use for the conversation.",
        )
        credentials: AICredentials = AICredentialsField()
        sys_prompt: str = SchemaField(
            title="System Prompt",
            default="",
            description="The system prompt to provide additional context to the model.",
        )
        conversation_history: list[Message] = SchemaField(
            default=[],
            description="The conversation history to provide context for the prompt.",
        )
        retry: int = SchemaField(
            title="Retry Count",
            default=3,
            description="Number of times to retry the LLM call if the response does not match the expected format.",
        )
        prompt_values: dict[str, str | Image] = SchemaField(
            advanced=False,
            default={},
            description="Values used to fill in the prompt. Images can be provided as base64 encoded data with MIME type.",
        )
        max_tokens: int | None = SchemaField(
            advanced=True,
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )

    class Output(BlockSchema):
        response: dict[str, Any] = SchemaField(
            description="The response object generated by the language model."
        )
        error: str = SchemaField(description="Error message if the API call failed.")

    def __init__(self):
        super().__init__(
            id="bc043b3e-2926-4ed7-b276-735535d1a945",
            description="Call Claude with support for images to generate formatted object based on the given prompt.",
            categories={BlockCategory.AI},
            input_schema=ClaudeWithImageBlock.Input,
            output_schema=ClaudeWithImageBlock.Output,
            test_input={
                "model": LlmModel.CLAUDE_3_5_SONNET,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expected_format": {
                    "key1": "value1",
                    "key2": "value2",
                },
                "prompt": "Describe this image",
                "prompt_values": {
                    "image": {
                        "content_type": "image/jpeg",
                        "data": "base64_encoded_test_image",
                    }
                },
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
            test_mock={
                "llm_call": lambda *args, **kwargs: (
                    json.dumps(
                        {
                            "key1": "key1Value",
                            "key2": "key2Value",
                        }
                    ),
                    0,
                    0,
                )
            },
        )

    @staticmethod
    def llm_call(
        credentials: APIKeyCredentials,
        llm_model: LlmModel,
        prompt: list[dict],
        max_tokens: int | None = None,
    ) -> tuple[str, int, int]:
        """
        Call the Claude API with support for images in the messages.

        Args:
            credentials: API credentials for Claude
            llm_model: The LLM model to use (must be an Anthropic model)
            prompt: List of message dictionaries; "content" may be a plain
                string or a list of Anthropic content parts (text/image)
            max_tokens: Maximum tokens to generate (defaults to 8192)

        Returns:
            tuple containing:
                - The text response
                - Number of input tokens used
                - Number of output tokens used

        Raises:
            ValueError: if a non-Anthropic model is requested, the API
                returns no content, or the API call fails.
        """
        if llm_model.metadata.provider != "anthropic":
            raise ValueError("Only Claude models are supported for image processing")

        # System messages are not part of the Anthropic `messages` array;
        # collapse them into the dedicated `system` parameter.
        system_messages = [p["content"] for p in prompt if p["role"] == "system"]
        sysprompt = " ".join(system_messages)

        # Build the messages array. Anthropic requires strictly alternating
        # user/assistant roles, so consecutive same-role messages are merged
        # into a single message's content list.
        messages = []
        last_role = None
        for p in prompt:
            if p["role"] in ["user", "assistant"]:
                message_content = []
                # Handle plain text content
                if isinstance(p["content"], str):
                    message_content.append({"type": "text", "text": p["content"]})
                # Handle mixed content array with images
                elif isinstance(p["content"], list):
                    message_content.extend(p["content"])

                if p["role"] != last_role:
                    messages.append({"role": p["role"], "content": message_content})
                    last_role = p["role"]
                else:
                    # Combine with previous message if same role
                    messages[-1]["content"].extend(message_content)

        client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
        try:
            resp = client.messages.create(
                model=llm_model.value,
                system=sysprompt,
                messages=messages,
                max_tokens=max_tokens or 8192,
            )

            if not resp.content:
                raise ValueError("No content returned from Anthropic.")

            return (
                (
                    resp.content[0].name
                    if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
                    else resp.content[0].text
                ),
                resp.usage.input_tokens,
                resp.usage.output_tokens,
            )
        except anthropic.APIError as e:
            error_message = f"Anthropic API error: {str(e)}"
            logger.error(error_message)
            # Chain the original API error so the full cause is preserved.
            raise ValueError(error_message) from e

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        logger.debug(f"Calling Claude with input data: {input_data}")

        # Start with any existing conversation history
        prompt = [p.model_dump() for p in input_data.conversation_history]

        def trim_prompt(s: str) -> str:
            # Strip the "|" margin markers used in the multiline templates below.
            lines = s.strip().split("\n")
            return "\n".join([line.strip().lstrip("|") for line in lines])

        # Handle prompt values including images
        content = []
        values = input_data.prompt_values

        # Add any images from prompt_values as Anthropic image content parts
        for key, value in values.items():
            if isinstance(value, dict) and "content_type" in value and "data" in value:
                # This is an image
                content.append(
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": value["content_type"],
                            "data": value["data"],
                        },
                    }
                )

        # Add the text prompt, substituting only string values (images are
        # already attached above as separate content parts)
        if input_data.prompt:
            content.append(
                {
                    "type": "text",
                    "text": fmt.format_string(
                        input_data.prompt,
                        {k: v for k, v in values.items() if isinstance(v, str)},
                    ),
                }
            )

        # Add system prompt if provided
        # NOTE(review): unlike the text prompt above, this passes image dicts
        # into fmt.format_string as well — confirm fmt.format_string tolerates
        # non-string values, or filter to strings here too.
        if input_data.sys_prompt:
            prompt.append(
                {
                    "role": "system",
                    "content": fmt.format_string(input_data.sys_prompt, values),
                }
            )

        # Add expected format instructions if provided
        if input_data.expected_format:
            expected_format = [
                f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
            ]
            format_prompt = ",\n  ".join(expected_format)
            sys_prompt = trim_prompt(
                f"""
                  |Reply strictly only in the following JSON format:
                  |{{
                  |  {format_prompt}
                  |}}
                """
            )
            prompt.append({"role": "system", "content": sys_prompt})

        # Add the main prompt with images and text
        prompt.append({"role": "user", "content": content})

        def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
            # Parse the model response as JSON and check the expected keys.
            # Returns (parsed_dict, error_message_or_None).
            try:
                parsed = json.loads(resp)
                if not isinstance(parsed, dict):
                    return {}, f"Expected a dictionary, but got {type(parsed)}"
                if input_data.expected_format:
                    miss_keys = set(input_data.expected_format.keys()) - set(
                        parsed.keys()
                    )
                    if miss_keys:
                        return parsed, f"Missing keys: {miss_keys}"
                return parsed, None
            except JSONDecodeError as e:
                return {}, f"JSON decode error: {e}"

        logger.info(f"Claude request: {prompt}")
        retry_prompt = ""
        llm_model = input_data.model

        for retry_count in range(input_data.retry):
            try:
                response_text, input_token, output_token = self.llm_call(
                    credentials=credentials,
                    llm_model=llm_model,
                    prompt=prompt,
                    max_tokens=input_data.max_tokens,
                )
                self.merge_stats(
                    {
                        "input_token_count": input_token,
                        "output_token_count": output_token,
                    }
                )
                logger.info(f"Claude attempt-{retry_count} response: {response_text}")

                if input_data.expected_format:
                    parsed_dict, parsed_error = parse_response(response_text)
                    if not parsed_error:
                        # Coerce JSON-array-looking strings into lists, and
                        # join real lists into comma-separated strings, so
                        # the output values are consistently shaped.
                        yield "response", {
                            k: (
                                json.loads(v)
                                if isinstance(v, str)
                                and v.startswith("[")
                                and v.endswith("]")
                                else (", ".join(v) if isinstance(v, list) else v)
                            )
                            for k, v in parsed_dict.items()
                        }
                        return
                else:
                    yield "response", {"response": response_text}
                    return

                # Validation failed: feed the error back to the model and retry.
                retry_prompt = trim_prompt(
                    f"""
                      |This is your previous error response:
                      |--
                      |{response_text}
                      |--
                      |
                      |And this is the error:
                      |--
                      |{parsed_error}
                      |--
                    """
                )
                prompt.append({"role": "user", "content": retry_prompt})
            except Exception as e:
                logger.exception(f"Error calling Claude: {e}")
                retry_prompt = f"Error calling Claude: {e}"
            finally:
                self.merge_stats(
                    {
                        "llm_call_count": retry_count + 1,
                        "llm_retry_count": retry_count,
                    }
                )

        # All retries exhausted without a valid response.
        raise RuntimeError(retry_prompt)