mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
feat(backend): Integrate Ideogram for auto-generating created agent thumbnail (#9568)
### Changes 🏗️ Integrate Ideogram for auto-generating created agent thumbnail  **Note:** switching back to the "old" Replicate-based image generator is possible by setting `USE_AGENT_IMAGE_GENERATION_V2=false`. ### Checklist 📋 #### For code changes: - [ ] I have clearly listed my changes in the PR description - [ ] I have made a test plan - [ ] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [ ] ... <details> <summary>Example test plan</summary> - [ ] Create from scratch and execute an agent with at least 3 blocks - [ ] Import an agent from file upload, and confirm it executes correctly - [ ] Upload agent to marketplace - [ ] Import an agent from marketplace and confirm it executes correctly - [ ] Edit an agent from monitor, and confirm it executes correctly </details> --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
@@ -142,6 +142,16 @@ class IdeogramModelBlock(Block):
|
||||
title="Color Palette Preset",
|
||||
advanced=True,
|
||||
)
|
||||
custom_color_palette: Optional[list[str]] = SchemaField(
|
||||
description=(
|
||||
"Only available for model version V_2 or V_2_TURBO. Provide one or more color hex codes "
|
||||
"(e.g., ['#000030', '#1C0C47', '#9900FF', '#4285F4', '#FFFFFF']) to define a custom color "
|
||||
"palette. Only used if 'color_palette_name' is 'NONE'."
|
||||
),
|
||||
default=None,
|
||||
title="Custom Color Palette",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: str = SchemaField(description="Generated image URL")
|
||||
@@ -164,6 +174,13 @@ class IdeogramModelBlock(Block):
|
||||
"style_type": StyleType.AUTO,
|
||||
"negative_prompt": None,
|
||||
"color_palette_name": ColorPalettePreset.NONE,
|
||||
"custom_color_palette": [
|
||||
"#000030",
|
||||
"#1C0C47",
|
||||
"#9900FF",
|
||||
"#4285F4",
|
||||
"#FFFFFF",
|
||||
],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
@@ -173,7 +190,7 @@ class IdeogramModelBlock(Block):
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"run_model": lambda api_key, model_name, prompt, seed, aspect_ratio, magic_prompt_option, style_type, negative_prompt, color_palette_name, custom_colors: "https://ideogram.ai/api/images/test-generated-image-url.png",
|
||||
"upscale_image": lambda api_key, image_url: "https://ideogram.ai/api/images/test-upscaled-image-url.png",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -195,6 +212,7 @@ class IdeogramModelBlock(Block):
|
||||
style_type=input_data.style_type.value,
|
||||
negative_prompt=input_data.negative_prompt,
|
||||
color_palette_name=input_data.color_palette_name.value,
|
||||
custom_colors=input_data.custom_color_palette,
|
||||
)
|
||||
|
||||
# Step 2: Upscale the image if requested
|
||||
@@ -217,6 +235,7 @@ class IdeogramModelBlock(Block):
|
||||
style_type: str,
|
||||
negative_prompt: Optional[str],
|
||||
color_palette_name: str,
|
||||
custom_colors: Optional[list[str]],
|
||||
):
|
||||
url = "https://api.ideogram.ai/generate"
|
||||
headers = {
|
||||
@@ -241,7 +260,11 @@ class IdeogramModelBlock(Block):
|
||||
data["image_request"]["negative_prompt"] = negative_prompt
|
||||
|
||||
if color_palette_name != "NONE":
|
||||
data["image_request"]["color_palette"] = {"name": color_palette_name}
|
||||
data["color_palette"] = {"name": color_palette_name}
|
||||
elif custom_colors:
|
||||
data["color_palette"] = {
|
||||
"members": [{"color_hex": color} for color in custom_colors]
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
@@ -267,9 +290,7 @@ class IdeogramModelBlock(Block):
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
data={
|
||||
"image_request": "{}", # Empty JSON object
|
||||
},
|
||||
data={"image_request": "{}"},
|
||||
files=files,
|
||||
)
|
||||
|
||||
|
||||
@@ -393,7 +393,8 @@ async def get_graph_all_versions(
|
||||
path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
create_graph: CreateGraph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.GraphModel:
|
||||
graph = graph_db.make_graph_model(create_graph.graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
@@ -401,10 +402,9 @@ async def create_new_graph(
|
||||
graph = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
# Create a library agent for the new graph
|
||||
await library_db.create_library_agent(
|
||||
graph.id,
|
||||
graph.version,
|
||||
user_id,
|
||||
library_agent = await library_db.create_library_agent(graph, user_id)
|
||||
_ = asyncio.create_task(
|
||||
library_db.add_generated_agent_image(graph, library_agent.id)
|
||||
)
|
||||
|
||||
graph = await on_graph_activate(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -7,14 +8,17 @@ import prisma.fields
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
import backend.data.graph
|
||||
import backend.data.includes
|
||||
import backend.server.model
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.exceptions as store_exceptions
|
||||
import backend.server.v2.store.image_gen as store_image_gen
|
||||
import backend.server.v2.store.media as store_media
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
async def list_library_agents(
|
||||
@@ -168,17 +172,53 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agent") from e
|
||||
|
||||
|
||||
async def add_generated_agent_image(
|
||||
graph: backend.data.graph.GraphModel,
|
||||
library_agent_id: str,
|
||||
) -> Optional[prisma.models.LibraryAgent]:
|
||||
"""
|
||||
Generates an image for the specified LibraryAgent and updates its record.
|
||||
"""
|
||||
user_id = graph.user_id
|
||||
graph_id = graph.id
|
||||
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{graph_id}.jpeg"
|
||||
try:
|
||||
if not (image_url := await store_media.check_media_exists(user_id, filename)):
|
||||
# Generate agent image as JPEG
|
||||
if config.use_agent_image_generation_v2:
|
||||
image = await asyncio.to_thread(
|
||||
store_image_gen.generate_agent_image_v2, graph=graph
|
||||
)
|
||||
else:
|
||||
image = await store_image_gen.generate_agent_image(agent=graph)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(file=image, filename=filename)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating and uploading agent image: {e}")
|
||||
return None
|
||||
|
||||
return await prisma.models.LibraryAgent.prisma().update(
|
||||
where={"id": library_agent_id},
|
||||
data={"imageUrl": image_url},
|
||||
)
|
||||
|
||||
|
||||
async def create_library_agent(
|
||||
agent_id: str,
|
||||
agent_version: int,
|
||||
graph: backend.data.graph.GraphModel,
|
||||
user_id: str,
|
||||
) -> prisma.models.LibraryAgent:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table).
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent to add.
|
||||
agent_version: The version of the agent to add.
|
||||
agent: The agent/Graph to add to the library.
|
||||
user_id: The user to whom the agent will be added.
|
||||
|
||||
Returns:
|
||||
@@ -189,52 +229,19 @@ async def create_library_agent(
|
||||
DatabaseError: If there's an error during creation or if image generation fails.
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating library agent for graph #{agent_id} v{agent_version}; "
|
||||
f"Creating library agent for graph #{graph.id} v{graph.version}; "
|
||||
f"user #{user_id}"
|
||||
)
|
||||
|
||||
# Fetch agent graph
|
||||
try:
|
||||
agent = await prisma.models.AgentGraph.prisma().find_unique(
|
||||
where={"graphVersionId": {"id": agent_id, "version": agent_version}}
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.exception("Database error fetching agent")
|
||||
raise store_exceptions.DatabaseError("Failed to fetch agent") from e
|
||||
|
||||
if not agent:
|
||||
raise store_exceptions.AgentNotFoundError(
|
||||
f"Agent #{agent_id} v{agent_version} not found"
|
||||
)
|
||||
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
try:
|
||||
if not (image_url := await store_media.check_media_exists(user_id, filename)):
|
||||
# Generate agent image as JPEG
|
||||
image = await store_image_gen.generate_agent_image(agent=agent)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(file=image, filename=filename)
|
||||
|
||||
image_url = await store_media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating and uploading agent image: {e}")
|
||||
image_url = None
|
||||
|
||||
try:
|
||||
return await prisma.models.LibraryAgent.prisma().create(
|
||||
data={
|
||||
"imageUrl": image_url,
|
||||
"isCreatedByUser": (user_id == agent.userId),
|
||||
"isCreatedByUser": (user_id == graph.user_id),
|
||||
"useGraphIsActiveVersion": True,
|
||||
"User": {"connect": {"id": user_id}},
|
||||
# "Creator": {"connect": {"id": agent.userId}},
|
||||
"Agent": {
|
||||
"connect": {
|
||||
"graphVersionId": {"id": agent_id, "version": agent_version}
|
||||
"graphVersionId": {"id": graph.id, "version": graph.version}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -4,14 +4,26 @@ from enum import Enum
|
||||
|
||||
import replicate
|
||||
import replicate.exceptions
|
||||
import requests
|
||||
from prisma.models import AgentGraph
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.blocks.ideogram import (
|
||||
AspectRatio,
|
||||
ColorPalettePreset,
|
||||
IdeogramModelBlock,
|
||||
IdeogramModelName,
|
||||
MagicPromptOption,
|
||||
StyleType,
|
||||
UpscaleOption,
|
||||
)
|
||||
from backend.data.graph import Graph
|
||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||
from backend.integrations.credentials_store import ideogram_credentials
|
||||
from backend.util.request import requests
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class ImageSize(str, Enum):
|
||||
@@ -22,6 +34,63 @@ class ImageStyle(str, Enum):
|
||||
DIGITAL_ART = "digital art"
|
||||
|
||||
|
||||
def generate_agent_image_v2(graph: Graph | AgentGraph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Ideogram model.
|
||||
Returns:
|
||||
str: The URL of the generated image
|
||||
"""
|
||||
if not ideogram_credentials.api_key:
|
||||
raise ValueError("Missing Ideogram API key")
|
||||
|
||||
name = graph.name
|
||||
description = f"{name} ({graph.description})" if graph.description else name
|
||||
|
||||
prompt = (
|
||||
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
||||
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
||||
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
||||
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
||||
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
||||
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
||||
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||
)
|
||||
|
||||
custom_colors = [
|
||||
"#000030",
|
||||
"#1C0C47",
|
||||
"#9900FF",
|
||||
"#4285F4",
|
||||
"#FFFFFF",
|
||||
]
|
||||
|
||||
# Run the Ideogram model block with the specified parameters
|
||||
url = IdeogramModelBlock().run_once(
|
||||
IdeogramModelBlock.Input(
|
||||
credentials=CredentialsMetaInput(
|
||||
id=ideogram_credentials.id,
|
||||
provider=ProviderName.IDEOGRAM,
|
||||
title=ideogram_credentials.title,
|
||||
type=ideogram_credentials.type,
|
||||
),
|
||||
prompt=prompt,
|
||||
ideogram_model_name=IdeogramModelName.V2,
|
||||
aspect_ratio=AspectRatio.ASPECT_4_3,
|
||||
magic_prompt_option=MagicPromptOption.OFF,
|
||||
style_type=StyleType.AUTO,
|
||||
upscale=UpscaleOption.NO_UPSCALE,
|
||||
color_palette_name=ColorPalettePreset.NONE,
|
||||
custom_color_palette=custom_colors,
|
||||
seed=None,
|
||||
negative_prompt=None,
|
||||
),
|
||||
"result",
|
||||
credentials=ideogram_credentials,
|
||||
)
|
||||
return io.BytesIO(requests.get(url).content)
|
||||
|
||||
|
||||
async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Flux model via Replicate API.
|
||||
@@ -33,8 +102,6 @@ async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
|
||||
io.BytesIO: The generated image as bytes
|
||||
"""
|
||||
try:
|
||||
settings = Settings()
|
||||
|
||||
if not settings.secrets.replicate_api_key:
|
||||
raise ValueError("Missing Replicate API key in settings")
|
||||
|
||||
@@ -71,14 +138,12 @@ async def generate_agent_image(agent: Graph | AgentGraph) -> io.BytesIO:
|
||||
# If it's a URL string, fetch the image bytes
|
||||
result_url = output[0]
|
||||
response = requests.get(result_url)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
elif isinstance(output, FileOutput):
|
||||
image_bytes = output.read()
|
||||
elif isinstance(output, str):
|
||||
# Output is a URL
|
||||
response = requests.get(output)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
else:
|
||||
raise RuntimeError("Unexpected output format from the model.")
|
||||
|
||||
@@ -206,6 +206,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The email address to use for sending emails",
|
||||
)
|
||||
|
||||
use_agent_image_generation_v2: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use the new agent image generation service",
|
||||
)
|
||||
|
||||
@field_validator("platform_base_url", "frontend_base_url")
|
||||
@classmethod
|
||||
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
|
||||
|
||||
@@ -9,6 +9,7 @@ const nextConfig = {
|
||||
"upload.wikimedia.org",
|
||||
"storage.googleapis.com",
|
||||
|
||||
"ideogram.ai", // for generated images
|
||||
"picsum.photos", // for placeholder images
|
||||
"dummyimage.com", // for placeholder images
|
||||
"placekitten.com", // for placeholder images
|
||||
|
||||
Reference in New Issue
Block a user