fix: address CodeRabbit review feedback

- Fix Union[Literal] -> single Literal for credentials type
- Remove dead URL branch; GPT-image models always return b64_json
- Move inner imports (BytesIO, base64) to top level
- Replace # type: ignore[arg-type] with cast() for size Literal
- Add MATCH_INPUT_IMAGE -> auto mapping for OpenAI size
- Guard data-URI split with validation
- Fix customizer: dispatch to images.edit when reference image provided,
  images.generate when no reference image
This commit is contained in:
Toran Bruce Richards
2026-04-22 20:38:46 +00:00
parent 23b5e1272e
commit 722a8ad534
3 changed files with 32 additions and 50 deletions

View File

@@ -32,7 +32,7 @@ class ImageCustomizerModel(str, Enum):
NANO_BANANA_PRO = "google/nano-banana-pro"
NANO_BANANA_2 = "google/nano-banana-2"
GPT_IMAGE_1 = "gpt-image-1"
GPT_IMAGE_1_5 = "gpt-image-1-5"
GPT_IMAGE_1_5 = "gpt-image-1.5"
GPT_IMAGE_2 = "gpt-image-2"
GPT_IMAGE_1_MINI = "gpt-image-1-mini"

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Literal, Union
from typing import Literal, cast
import openai
from pydantic import SecretStr
@@ -126,7 +126,7 @@ class ImageGenModel(str, Enum):
NANO_BANANA_PRO = "Nano Banana Pro"
NANO_BANANA_2 = "Nano Banana 2"
GPT_IMAGE_1 = "gpt-image-1"
GPT_IMAGE_1_5 = "gpt-image-1-5"
GPT_IMAGE_1_5 = "gpt-image-1.5"
GPT_IMAGE_2 = "gpt-image-2"
GPT_IMAGE_1_MINI = "gpt-image-1-mini"
@@ -134,7 +134,7 @@ class ImageGenModel(str, Enum):
class AIImageGeneratorBlock(Block):
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput[
Union[Literal[ProviderName.REPLICATE], Literal[ProviderName.OPENAI]],
Literal[ProviderName.REPLICATE, ProviderName.OPENAI],
Literal["api_key"],
] = CredentialsField(
description="Enter your Replicate or OpenAI API key to access the image generation API.",
@@ -188,12 +188,10 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
# Test output is a data URI since we now store images
lambda x: x.startswith("data:image/"),
),
],
test_mock={
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: (
"data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
),
@@ -207,13 +205,9 @@ class AIImageGeneratorBlock(Block):
self, credentials: APIKeyCredentials, model_name: str, input_params: dict
):
try:
# Initialize Replicate client
client = ReplicateClient(api_token=credentials.api_key.get_secret_value())
# Run the model with input parameters
output = await client.async_run(model_name, input=input_params, wait=False)
# Process output
if isinstance(output, list) and len(output) > 0:
if isinstance(output[0], FileOutput):
result_url = output[0].url
@@ -236,39 +230,35 @@ class AIImageGeneratorBlock(Block):
async def _generate_with_openai(
self, input_data: Input, credentials: APIKeyCredentials
) -> str:
client = openai.AsyncOpenAI(
api_key=credentials.api_key.get_secret_value()
)
client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
size = SIZE_TO_OPENAI.get(input_data.size, "1024x1024")
size_literal = cast(
Literal["1024x1024", "1536x1024", "1024x1536"], size
)
response = await client.images.generate(
model=input_data.model.value,
prompt=input_data.prompt,
n=1,
size=size, # type: ignore[arg-type]
size=size_literal,
quality="auto",
)
if response.data and response.data[0].url:
return response.data[0].url
if response.data and response.data[0].b64_json:
return f"data:image/png;base64,{response.data[0].b64_json}"
raise RuntimeError("OpenAI image generation returned empty result")
if not response.data or not response.data[0].b64_json:
raise RuntimeError("OpenAI image generation returned empty result")
return f"data:image/png;base64,{response.data[0].b64_json}"
async def generate_image(self, input_data: Input, credentials: APIKeyCredentials):
try:
# Route to OpenAI for GPT-image models
if input_data.model.value.startswith("gpt-image"):
return await self._generate_with_openai(input_data, credentials)
# Handle style-based prompt modification for models without native style support
modified_prompt = input_data.prompt
if input_data.model not in [ImageGenModel.RECRAFT]:
style_prefix = self._style_to_prompt_prefix(input_data.style)
modified_prompt = f"{style_prefix} {modified_prompt}".strip()
if input_data.model == ImageGenModel.SD3_5:
# Use Stable Diffusion 3.5 with aspect ratio
input_params = {
"prompt": modified_prompt,
"aspect_ratio": SIZE_TO_SD_RATIO[input_data.size],
@@ -285,14 +275,13 @@ class AIImageGeneratorBlock(Block):
return output
elif input_data.model == ImageGenModel.FLUX:
# Use Flux-specific dimensions with 'jpg' format to avoid ReplicateError
width, height = SIZE_TO_FLUX_DIMENSIONS[input_data.size]
input_params = {
"prompt": modified_prompt,
"width": width,
"height": height,
"aspect_ratio": SIZE_TO_FLUX_RATIO[input_data.size],
"output_format": "jpg", # Set to jpg for Flux models
"output_format": "jpg",
"output_quality": 90,
}
output = await self._run_client(
@@ -330,7 +319,6 @@ class AIImageGeneratorBlock(Block):
ImageGenModel.NANO_BANANA_PRO,
ImageGenModel.NANO_BANANA_2,
):
# Use Nano Banana models (Google Gemini image variants)
model_map = {
ImageGenModel.NANO_BANANA_PRO: "google/nano-banana-pro",
ImageGenModel.NANO_BANANA_2: "google/nano-banana-2",
@@ -351,9 +339,6 @@ class AIImageGeneratorBlock(Block):
raise RuntimeError(f"Failed to generate image: {str(e)}")
def _style_to_prompt_prefix(self, style: ImageStyle) -> str:
"""
Convert a style enum to a prompt prefix for models without native style support.
"""
if style == ImageStyle.ANY:
return ""
@@ -392,7 +377,6 @@ class AIImageGeneratorBlock(Block):
try:
url = await self.generate_image(input_data, credentials)
if url:
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
@@ -402,11 +386,9 @@ class AIImageGeneratorBlock(Block):
else:
yield "error", "Image generation returned an empty result."
except Exception as e:
# Capture and return only the message of the exception, avoiding serialization of non-serializable objects
yield "error", str(e)
# Test credentials stay the same
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="replicate",

View File

@@ -1,5 +1,7 @@
from enum import Enum
from typing import Literal, Optional, Union
from io import BytesIO
import base64
from typing import Literal, Optional, cast
import openai
from pydantic import SecretStr
@@ -45,7 +47,7 @@ class ImageEditorModel(str, Enum):
NANO_BANANA_PRO = "Nano Banana Pro"
NANO_BANANA_2 = "Nano Banana 2"
GPT_IMAGE_1 = "gpt-image-1"
GPT_IMAGE_1_5 = "gpt-image-1-5"
GPT_IMAGE_1_5 = "gpt-image-1.5"
GPT_IMAGE_2 = "gpt-image-2"
GPT_IMAGE_1_MINI = "gpt-image-1-mini"
@@ -82,6 +84,7 @@ class AspectRatio(str, Enum):
ASPECT_TO_OPENAI_SIZE = {
AspectRatio.MATCH_INPUT_IMAGE: "auto",
AspectRatio.ASPECT_1_1: "1024x1024",
AspectRatio.ASPECT_16_9: "1536x1024",
AspectRatio.ASPECT_9_16: "1024x1536",
@@ -101,7 +104,7 @@ ASPECT_TO_OPENAI_SIZE = {
class AIImageEditorBlock(Block):
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput[
Union[Literal[ProviderName.REPLICATE], Literal[ProviderName.OPENAI]],
Literal[ProviderName.REPLICATE, ProviderName.OPENAI],
Literal["api_key"],
] = CredentialsField(
description="Replicate or OpenAI API key with permissions for image editing models",
@@ -157,13 +160,11 @@ class AIImageEditorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
# Output will be a workspace ref or data URI depending on context
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAhKmMIQAAAABJRU5ErkJggg=="
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
),
},
test_credentials=TEST_CREDENTIALS,
@@ -185,7 +186,7 @@ class AIImageEditorBlock(Block):
await store_media_file(
file=input_data.input_image,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
return_format="for_external_api",
)
if input_data.input_image
else None
@@ -195,7 +196,6 @@ class AIImageEditorBlock(Block):
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
)
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
@@ -215,26 +215,28 @@ class AIImageEditorBlock(Block):
raise ValueError("OpenAI image editing requires an input image.")
client = openai.AsyncOpenAI(api_key=api_key.get_secret_value())
from io import BytesIO
import base64
header, encoded = str(input_image_b64).split(",", 1)
data_uri = str(input_image_b64)
if "," not in data_uri:
raise ValueError("Expected a data-URI for the input image.")
_, encoded = data_uri.split(",", 1)
image_bytes = BytesIO(base64.b64decode(encoded))
size = ASPECT_TO_OPENAI_SIZE.get(aspect_ratio, "1024x1024")
size_literal = cast(
Literal["1024x1024", "1536x1024", "1024x1536", "auto"], size
)
response = await client.images.edit(
model=model.value,
image=image_bytes,
prompt=prompt,
n=1,
size=size, # type: ignore[arg-type]
size=size_literal,
)
if response.data and response.data[0].url:
return MediaFileType(response.data[0].url)
if response.data and response.data[0].b64_json:
return MediaFileType(f"data:image/png;base64,{response.data[0].b64_json}")
raise ValueError("OpenAI image edit returned empty result")
if not response.data or not response.data[0].b64_json:
raise ValueError("OpenAI image edit returned empty result")
return MediaFileType(f"data:image/png;base64,{response.data[0].b64_json}")
async def run_model(
self,
@@ -247,7 +249,6 @@ class AIImageEditorBlock(Block):
user_id: str,
graph_exec_id: str,
) -> MediaFileType:
# Route to OpenAI for GPT-image models
if model.value.startswith("gpt-image"):
return await self._edit_with_openai(
api_key, model, prompt, input_image_b64, aspect_ratio
@@ -267,7 +268,6 @@ class AIImageEditorBlock(Block):
"output_format": "jpg",
"safety_filter_level": "block_only_high",
}
# NB API expects "image_input" as a list, unlike Flux's single "input_image"
if input_image_b64:
input_params["image_input"] = [input_image_b64]
else: