feat(backend): Integrate Ideogram for auto-generating created agent thumbnail (#9568)

### Changes 🏗️

Integrate Ideogram for auto-generating created agent thumbnail

![preview of the UI with a generated
image](https://github.com/user-attachments/assets/87fb4179-59e0-4109-b5aa-d45ebe9decf7)

**Note:** switching back to the "old" Replicate-based image generator is
possible by setting `USE_AGENT_IMAGE_GENERATION_V2=false`.

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [ ] ...

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
Zamil Majdy
2025-03-05 21:41:42 +07:00
committed by GitHub
parent b50ecbf7e9
commit c091a7be62
6 changed files with 154 additions and 55 deletions

View File

@@ -142,6 +142,16 @@ class IdeogramModelBlock(Block):
title="Color Palette Preset",
advanced=True,
)
custom_color_palette: Optional[list[str]] = SchemaField(
description=(
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
"palette. Only used if 'color_palette_name' is 'NONE'."
),
default=None,
title="Custom Color Palette",
advanced=True,
)
class Output(BlockSchema):
result: str = SchemaField(description="Generated image URL")
@@ -164,6 +174,13 @@ class IdeogramModelBlock(Block):
"style_type": StyleType.AUTO,
"negative_prompt": None,
"color_palette_name": ColorPalettePreset.NONE,
"custom_color_palette": [
"#000030",
"#1C0C47",
"#9900FF",
"#4285F4",
"#FFFFFF",
],
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
@@ -173,7 +190,7 @@ class IdeogramModelBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
},
test_credentials=TEST_CREDENTIALS,
@@ -195,6 +212,7 @@ class IdeogramModelBlock(Block):
style_type=input_data.style_type.value,
negative_prompt=input_data.negative_prompt,
color_palette_name=input_data.color_palette_name.value,
custom_colors=input_data.custom_color_palette,
)
# Step 2: Upscale the image if requested
@@ -217,6 +235,7 @@ class IdeogramModelBlock(Block):
style_type: str,
negative_prompt: Optional[str],
color_palette_name: str,
custom_colors: Optional[list[str]],
):
url = "https://api.ideogram.ai/generate"
headers = {
@@ -241,7 +260,11 @@ class IdeogramModelBlock(Block):
data["image_request"]["negative_prompt"] = negative_prompt
if color_palette_name != "NONE":
data["image_request"]["color_palette"] = {"name": color_palette_name}
data["color_palette"] = {"name": color_palette_name}
elif custom_colors:
data["color_palette"] = {
"members": [{"color_hex": color} for color in custom_colors]
}
try:
response = requests.post(url, json=data, headers=headers)
@@ -267,9 +290,7 @@ class IdeogramModelBlock(Block):
response = requests.post(
url,
headers=headers,
data={
"image_request": "{}", # Empty JSON object
},
data={"image_request": "{}"},
files=files,
)

View File

@@ -393,7 +393,8 @@ async def get_graph_all_versions(
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
create_graph: CreateGraph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphModel:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
@@ -401,10 +402,9 @@ async def create_new_graph(
graph = await graph_db.create_graph(graph, user_id=user_id)
# Create a library agent for the new graph
await library_db.create_library_agent(
graph.id,
graph.version,
user_id,
library_agent = await library_db.create_library_agent(graph, user_id)
_ = asyncio.create_task(
library_db.add_generated_agent_image(graph, library_agent.id)
)
graph = await on_graph_activate(

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Optional
@@ -7,14 +8,17 @@ import prisma.fields
import prisma.models
import prisma.types
import backend.data.graph
import backend.data.includes
import backend.server.model
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
async def list_library_agents(
@@ -168,17 +172,53 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
async def add_generated_agent_image(
graph: backend.data.graph.GraphModel,
library_agent_id: str,
) -> Optional[prisma.models.LibraryAgent]:
"""
Generates an image for the specified LibraryAgent and updates its record.
"""
user_id = graph.user_id
graph_id = graph.id
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{graph_id}.jpeg"
try:
if not (image_url := await store_media.check_media_exists(user_id, filename)):
# Generate agent image as JPEG
if config.use_agent_image_generation_v2:
image = await asyncio.to_thread(
store_image_gen.generate_agent_image_v2, graph=graph
)
else:
image = await store_image_gen.generate_agent_image(agent=graph)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(file=image, filename=filename)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
except Exception as e:
logger.warning(f"Error generating and uploading agent image: {e}")
return None
return await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent_id},
data={"imageUrl": image_url},
)
async def create_library_agent(
agent_id: str,
agent_version: int,
graph: backend.data.graph.GraphModel,
user_id: str,
) -> prisma.models.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table).
Args:
agent_id: The ID of the agent to add.
agent_version: The version of the agent to add.
agent: The agent/Graph to add to the library.
user_id: The user to whom the agent will be added.
Returns:
@@ -189,52 +229,19 @@ async def create_library_agent(
DatabaseError: If there's an error during creation or if image generation fails.
"""
logger.info(
f"Creating library agent for graph #{agent_id} v{agent_version}; "
f"Creating library agent for graph #{graph.id} v{graph.version}; "
f"user #{user_id}"
)
# Fetch agent graph
try:
agent = await prisma.models.AgentGraph.prisma().find_unique(
where={"graphVersionId": {"id": agent_id, "version": agent_version}}
)
except prisma.errors.PrismaError as e:
logger.exception("Database error fetching agent")
raise store_exceptions.DatabaseError("Failed to fetch agent") from e
if not agent:
raise store_exceptions.AgentNotFoundError(
f"Agent #{agent_id} v{agent_version} not found"
)
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{agent_id}.jpeg"
try:
if not (image_url := await store_media.check_media_exists(user_id, filename)):
# Generate agent image as JPEG
image = await store_image_gen.generate_agent_image(agent=agent)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(file=image, filename=filename)
image_url = await store_media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
except Exception as e:
logger.warning(f"Error generating and uploading agent image: {e}")
image_url = None
try:
return await prisma.models.LibraryAgent.prisma().create(
data={
"imageUrl": image_url,
"isCreatedByUser": (user_id == agent.userId),
"isCreatedByUser": (user_id == graph.user_id),
"useGraphIsActiveVersion": True,
"User": {"connect": {"id": user_id}},
# "Creator": {"connect": {"id": agent.userId}},
"Agent": {
"connect": {
"graphVersionId": {"id": agent_id, "version": agent_version}
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
}

View File

@@ -4,14 +4,26 @@ from enum import Enum
import replicate
import replicate.exceptions
import requests
from prisma.models import AgentGraph
from replicate.helpers import FileOutput
from backend.blocks.ideogram import (
AspectRatio,
ColorPalettePreset,
IdeogramModelBlock,
IdeogramModelName,
MagicPromptOption,
StyleType,
UpscaleOption,
)
from backend.data.graph import Graph
from backend.data.model import CredentialsMetaInput, ProviderName
from backend.integrations.credentials_store import ideogram_credentials
from backend.util.request import requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
class ImageSize(str, Enum):
@@ -22,6 +34,63 @@ class ImageStyle(str, Enum):
DIGITAL_ART = "digital art"
def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Ideogram model.
Returns:
str: The URL of the generated image
"""
if not ideogram_credentials.api_key:
raise ValueError("Missing Ideogram API key")
name = graph.name
description = f"{name} ({graph.description})" if graph.description else name
prompt = (
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
f"along with recognizable objects directly associated with the primary function of a {name}. "
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
f"prioritizing clear visual storytelling and thematic clarity above all else."
)
custom_colors = [
"#000030",
"#1C0C47",
"#9900FF",
"#4285F4",
"#FFFFFF",
]
# Run the Ideogram model block with the specified parameters
url = IdeogramModelBlock().run_once(
IdeogramModelBlock.Input(
credentials=CredentialsMetaInput(
id=ideogram_credentials.id,
provider=ProviderName.IDEOGRAM,
title=ideogram_credentials.title,
type=ideogram_credentials.type,
),
prompt=prompt,
ideogram_model_name=IdeogramModelName.V2,
aspect_ratio=AspectRatio.ASPECT_4_3,
magic_prompt_option=MagicPromptOption.OFF,
style_type=StyleType.AUTO,
upscale=UpscaleOption.NO_UPSCALE,
color_palette_name=ColorPalettePreset.NONE,
custom_color_palette=custom_colors,
seed=None,
negative_prompt=None,
),
"result",
credentials=ideogram_credentials,
)
return io.BytesIO(requests.get(url).content)
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
"""
Generate an image for an agent using Flux model via Replicate API.
@@ -33,8 +102,6 @@ async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
io.BytesIO: The generated image as bytes
"""
try:
settings = Settings()
if not settings.secrets.replicate_api_key:
raise ValueError("Missing Replicate API key in settings")
@@ -71,14 +138,12 @@ async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
# If it's a URL string, fetch the image bytes
result_url = output[0]
response = requests.get(result_url)
response.raise_for_status()
image_bytes = response.content
elif isinstance(output, FileOutput):
image_bytes = output.read()
elif isinstance(output, str):
# Output is a URL
response = requests.get(output)
response.raise_for_status()
image_bytes = response.content
else:
raise RuntimeError("Unexpected output format from the model.")

View File

@@ -206,6 +206,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The email address to use for sending emails",
)
use_agent_image_generation_v2: bool = Field(
default=True,
description="Whether to use the new agent image generation service",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:

View File

@@ -9,6 +9,7 @@ const nextConfig = {
"upload.wikimedia.org",
"storage.googleapis.com",
"ideogram.ai", // for generated images
"picsum.photos", // for placeholder images
"dummyimage.com", // for placeholder images
"placekitten.com", // for placeholder images