mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat: image block for claude
This commit is contained in:
@@ -4,7 +4,7 @@ from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple, TypedDict
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -1173,3 +1173,328 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"Retry prompt: {prompt}")
|
||||
|
||||
logger.debug("AIListGeneratorBlock.run completed")
|
||||
|
||||
|
||||
class ClaudeWithImageBlock(Block):
    """Block for calling Claude API with support for images.

    Accepts a text prompt plus optional base64-encoded images via
    ``prompt_values``, sends them to an Anthropic Claude model, and (when
    ``expected_format`` is given) validates/retries until the model replies
    with a JSON object containing the expected keys.
    """

    class Input(BlockSchema):
        class Image(TypedDict):
            # Shape of an image value inside `prompt_values`.
            content_type: str  # MIME type of the image (e.g. "image/jpeg")
            data: str  # Base64 encoded image data (no data-URI prefix)

        prompt: str = SchemaField(
            description="The prompt to send to the language model.",
            placeholder="Enter your prompt here...",
        )
        expected_format: dict[str, str] = SchemaField(
            description="Expected format of the response. If provided, the response will be validated against this format. "
            "The keys should be the expected fields in the response, and the values should be the description of the field.",
        )
        model: LlmModel = SchemaField(
            title="LLM Model",
            default=LlmModel.CLAUDE_3_5_SONNET,
            description="The language model to use for the conversation.",
        )
        credentials: AICredentials = AICredentialsField()
        sys_prompt: str = SchemaField(
            title="System Prompt",
            default="",
            description="The system prompt to provide additional context to the model.",
        )
        conversation_history: list[Message] = SchemaField(
            default=[],
            description="The conversation history to provide context for the prompt.",
        )
        retry: int = SchemaField(
            title="Retry Count",
            default=3,
            description="Number of times to retry the LLM call if the response does not match the expected format.",
        )
        prompt_values: dict[str, str | Image] = SchemaField(
            advanced=False,
            default={},
            description="Values used to fill in the prompt. Images can be provided as base64 encoded data with MIME type.",
        )
        max_tokens: int | None = SchemaField(
            advanced=True,
            default=None,
            description="The maximum number of tokens to generate in the chat completion.",
        )

    class Output(BlockSchema):
        response: dict[str, Any] = SchemaField(
            description="The response object generated by the language model."
        )
        error: str = SchemaField(description="Error message if the API call failed.")

    def __init__(self):
        """Register block metadata, schemas, and test fixtures with the framework."""
        super().__init__(
            id="bc043b3e-2926-4ed7-b276-735535d1a945",
            description="Call Claude with support for images to generate formatted object based on the given prompt.",
            categories={BlockCategory.AI},
            input_schema=ClaudeWithImageBlock.Input,
            output_schema=ClaudeWithImageBlock.Output,
            test_input={
                "model": LlmModel.CLAUDE_3_5_SONNET,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expected_format": {
                    "key1": "value1",
                    "key2": "value2",
                },
                "prompt": "Describe this image",
                "prompt_values": {
                    "image": {
                        "content_type": "image/jpeg",
                        "data": "base64_encoded_test_image",
                    }
                },
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
            # Mock returns (response_text, input_tokens, output_tokens),
            # matching llm_call's return signature.
            test_mock={
                "llm_call": lambda *args, **kwargs: (
                    json.dumps(
                        {
                            "key1": "key1Value",
                            "key2": "key2Value",
                        }
                    ),
                    0,
                    0,
                )
            },
        )

    @staticmethod
    def llm_call(
        credentials: APIKeyCredentials,
        llm_model: LlmModel,
        prompt: list[dict],
        max_tokens: int | None = None,
    ) -> tuple[str, int, int]:
        """
        Call the Claude API with support for images in the messages.

        Args:
            credentials: API credentials for Claude
            llm_model: The LLM model to use (must be Claude)
            prompt: List of message dictionaries that can include image content
            max_tokens: Maximum tokens to generate

        Returns:
            tuple containing:
            - The text response
            - Number of input tokens used
            - Number of output tokens used
        """
        if llm_model.metadata.provider != "anthropic":
            raise ValueError("Only Claude models are supported for image processing")

        # Extract system prompt if present — the Anthropic Messages API takes
        # the system prompt as a separate `system` argument, not as a message.
        system_messages = [p["content"] for p in prompt if p["role"] == "system"]
        sysprompt = " ".join(system_messages)

        # Build messages array with content that can include images
        messages = []
        last_role = None

        for p in prompt:
            if p["role"] in ["user", "assistant"]:
                message_content = []

                # Handle text content
                if isinstance(p["content"], str):
                    message_content.append({"type": "text", "text": p["content"]})
                # Handle mixed content array with images
                elif isinstance(p["content"], list):
                    message_content.extend(p["content"])

                if p["role"] != last_role:
                    messages.append({"role": p["role"], "content": message_content})
                    last_role = p["role"]
                else:
                    # Combine with previous message if same role — the API
                    # expects strictly alternating user/assistant turns.
                    # NOTE(review): `last_role` is not reset by intervening
                    # system entries, so two user messages separated only by a
                    # system entry are merged; presumably intentional.
                    messages[-1]["content"].extend(message_content)

        client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())

        try:
            resp = client.messages.create(
                model=llm_model.value,
                system=sysprompt,
                messages=messages,
                # NOTE(review): `max_tokens or 8192` treats an explicit 0 the
                # same as None (falls back to 8192).
                max_tokens=max_tokens or 8192,
            )

            if not resp.content:
                raise ValueError("No content returned from Anthropic.")

            return (
                # Tool-use blocks carry a `name` instead of `text`.
                (
                    resp.content[0].name
                    if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
                    else resp.content[0].text
                ),
                resp.usage.input_tokens,
                resp.usage.output_tokens,
            )

        except anthropic.APIError as e:
            # Re-raise as ValueError so callers need not depend on the
            # anthropic exception hierarchy.
            error_message = f"Anthropic API error: {str(e)}"
            logger.error(error_message)
            raise ValueError(error_message)

    def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """Build the multimodal prompt, call Claude, and yield the response.

        Yields ("response", dict). When `expected_format` is set, the reply is
        parsed as JSON and retried up to `input_data.retry` times with an error
        feedback message appended to the conversation; if all attempts fail,
        raises RuntimeError with the last retry/error text.
        """
        logger.debug(f"Calling Claude with input data: {input_data}")

        # Start with any existing conversation history
        prompt = [p.model_dump() for p in input_data.conversation_history]

        def trim_prompt(s: str) -> str:
            # Strip indentation and the leading "|" margin markers used in the
            # triple-quoted templates below.
            lines = s.strip().split("\n")
            return "\n".join([line.strip().lstrip("|") for line in lines])

        # Handle prompt values including images
        content = []
        values = input_data.prompt_values

        # Add any images from prompt_values (detected structurally by the
        # presence of both "content_type" and "data" keys — see Input.Image).
        for key, value in values.items():
            if isinstance(value, dict) and "content_type" in value and "data" in value:
                # This is an image
                content.append(
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": value["content_type"],
                            "data": value["data"],
                        },
                    }
                )

        # Add the text prompt; only string values are used for placeholder
        # substitution (image dicts are excluded).
        if input_data.prompt:
            content.append(
                {
                    "type": "text",
                    "text": fmt.format_string(
                        input_data.prompt,
                        {k: v for k, v in values.items() if isinstance(v, str)},
                    ),
                }
            )

        # Add system prompt if provided
        if input_data.sys_prompt:
            prompt.append(
                {
                    "role": "system",
                    "content": fmt.format_string(input_data.sys_prompt, values),
                }
            )

        # Add expected format if provided — instructs the model to answer in
        # strict JSON with the expected keys.
        if input_data.expected_format:
            expected_format = [
                f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
            ]
            format_prompt = ",\n  ".join(expected_format)
            sys_prompt = trim_prompt(
                f"""
                  |Reply strictly only in the following JSON format:
                  |{{
                  |  {format_prompt}
                  |}}
                """
            )
            prompt.append({"role": "system", "content": sys_prompt})

        # Add the main prompt with images and text
        prompt.append({"role": "user", "content": content})

        def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
            """Parse the model reply as JSON; return (parsed, error_or_None)."""
            try:
                parsed = json.loads(resp)
                if not isinstance(parsed, dict):
                    return {}, f"Expected a dictionary, but got {type(parsed)}"
                if input_data.expected_format:
                    miss_keys = set(input_data.expected_format.keys()) - set(
                        parsed.keys()
                    )
                    if miss_keys:
                        return parsed, f"Missing keys: {miss_keys}"
                return parsed, None
            except JSONDecodeError as e:
                return {}, f"JSON decode error: {e}"

        logger.info(f"Claude request: {prompt}")
        retry_prompt = ""
        llm_model = input_data.model

        for retry_count in range(input_data.retry):
            try:
                response_text, input_token, output_token = self.llm_call(
                    credentials=credentials,
                    llm_model=llm_model,
                    prompt=prompt,
                    max_tokens=input_data.max_tokens,
                )

                self.merge_stats(
                    {
                        "input_token_count": input_token,
                        "output_token_count": output_token,
                    }
                )

                logger.info(f"Claude attempt-{retry_count} response: {response_text}")

                if input_data.expected_format:
                    parsed_dict, parsed_error = parse_response(response_text)
                    if not parsed_error:
                        # Normalize values: strings that look like JSON lists
                        # are decoded; actual lists are joined into a string.
                        yield "response", {
                            k: (
                                json.loads(v)
                                if isinstance(v, str)
                                and v.startswith("[")
                                and v.endswith("]")
                                else (", ".join(v) if isinstance(v, list) else v)
                            )
                            for k, v in parsed_dict.items()
                        }
                        return
                else:
                    yield "response", {"response": response_text}
                    return

                # Validation failed: feed the bad response and the error back
                # to the model and retry.
                retry_prompt = trim_prompt(
                    f"""
                      |This is your previous error response:
                      |--
                      |{response_text}
                      |--
                      |
                      |And this is the error:
                      |--
                      |{parsed_error}
                      |--
                    """
                )
                prompt.append({"role": "user", "content": retry_prompt})

            except Exception as e:
                # Broad catch by design: any failure (API error, parse issue)
                # becomes the retry message for the next attempt.
                logger.exception(f"Error calling Claude: {e}")
                retry_prompt = f"Error calling Claude: {e}"
            finally:
                self.merge_stats(
                    {
                        "llm_call_count": retry_count + 1,
                        "llm_retry_count": retry_count,
                    }
                )

        # All retries exhausted — surface the last error/retry text.
        raise RuntimeError(retry_prompt)
|
||||
|
||||
Reference in New Issue
Block a user