mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(store): Generate AI images for store submissions (#9090)
Allow generating ai images for store submissions
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
import io
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
import replicate
|
||||
import replicate.exceptions
|
||||
import requests
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.graph import Graph
|
||||
from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageSize(str, Enum):
|
||||
LANDSCAPE = "1024x768"
|
||||
|
||||
|
||||
class ImageStyle(str, Enum):
|
||||
DIGITAL_ART = "digital art"
|
||||
|
||||
|
||||
async def generate_agent_image(agent: Graph) -> io.BytesIO:
|
||||
"""
|
||||
Generate an image for an agent using Flux model via Replicate API.
|
||||
|
||||
Args:
|
||||
agent (Graph): The agent to generate an image for
|
||||
|
||||
Returns:
|
||||
io.BytesIO: The generated image as bytes
|
||||
"""
|
||||
try:
|
||||
settings = Settings()
|
||||
|
||||
if not settings.secrets.replicate_api_key:
|
||||
raise ValueError("Missing Replicate API key in settings")
|
||||
|
||||
# Construct prompt from agent details
|
||||
prompt = f"App store image for AI agent that gives a cool visual representation of what the agent does: - {agent.name} - {agent.description}"
|
||||
|
||||
# Set up Replicate client
|
||||
client = replicate.Client(api_token=settings.secrets.replicate_api_key)
|
||||
|
||||
# Model parameters
|
||||
input_data = {
|
||||
"prompt": prompt,
|
||||
"width": 1024,
|
||||
"height": 768,
|
||||
"aspect_ratio": "4:3",
|
||||
"output_format": "jpg",
|
||||
"output_quality": 90,
|
||||
"num_inference_steps": 30,
|
||||
"guidance": 3.5,
|
||||
"negative_prompt": "blurry, low quality, distorted, deformed",
|
||||
"disable_safety_checker": True,
|
||||
}
|
||||
|
||||
try:
|
||||
# Run model
|
||||
output = client.run("black-forest-labs/flux-pro", input=input_data)
|
||||
|
||||
# Depending on the model output, extract the image URL or bytes
|
||||
# If the output is a list of FileOutput or URLs
|
||||
if isinstance(output, list) and output:
|
||||
if isinstance(output[0], FileOutput):
|
||||
image_bytes = output[0].read()
|
||||
else:
|
||||
# If it's a URL string, fetch the image bytes
|
||||
result_url = output[0]
|
||||
response = requests.get(result_url)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
elif isinstance(output, FileOutput):
|
||||
image_bytes = output.read()
|
||||
elif isinstance(output, str):
|
||||
# Output is a URL
|
||||
response = requests.get(output)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
else:
|
||||
raise RuntimeError("Unexpected output format from the model.")
|
||||
|
||||
return io.BytesIO(image_bytes)
|
||||
|
||||
except replicate.exceptions.ReplicateError as e:
|
||||
if e.status == 401:
|
||||
raise RuntimeError("Invalid Replicate API token") from e
|
||||
raise RuntimeError(f"Replicate API error: {str(e)}") from e
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate agent image")
|
||||
raise RuntimeError(f"Image generation failed: {str(e)}")
|
||||
@@ -15,7 +15,45 @@ ALLOWED_VIDEO_TYPES = {"video/mp4", "video/webm"}
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
|
||||
async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
|
||||
async def check_media_exists(user_id: str, filename: str) -> str | None:
|
||||
"""
|
||||
Check if a media file exists in storage for the given user.
|
||||
Tries both images and videos directories.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user who uploaded the file
|
||||
filename (str): Name of the file to check
|
||||
|
||||
Returns:
|
||||
str | None: URL of the blob if it exists, None otherwise
|
||||
"""
|
||||
try:
|
||||
settings = Settings()
|
||||
storage_client = storage.Client()
|
||||
bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)
|
||||
|
||||
# Check images
|
||||
image_path = f"users/{user_id}/images/{filename}"
|
||||
image_blob = bucket.blob(image_path)
|
||||
if image_blob.exists():
|
||||
return image_blob.public_url
|
||||
|
||||
# Check videos
|
||||
video_path = f"users/{user_id}/videos/{filename}"
|
||||
|
||||
video_blob = bucket.blob(video_path)
|
||||
if video_blob.exists():
|
||||
return video_blob.public_url
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if media file exists: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def upload_media(
|
||||
user_id: str, file: fastapi.UploadFile, use_file_name: bool = False
|
||||
) -> str:
|
||||
|
||||
# Get file content for deeper validation
|
||||
try:
|
||||
@@ -84,6 +122,9 @@ async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
|
||||
try:
|
||||
# Validate file type
|
||||
content_type = file.content_type
|
||||
if content_type is None:
|
||||
content_type = "image/jpeg"
|
||||
|
||||
if (
|
||||
content_type not in ALLOWED_IMAGE_TYPES
|
||||
and content_type not in ALLOWED_VIDEO_TYPES
|
||||
@@ -119,7 +160,10 @@ async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
|
||||
# Generate unique filename
|
||||
filename = file.filename or ""
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
if use_file_name:
|
||||
unique_filename = filename
|
||||
else:
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
|
||||
# Construct storage path
|
||||
media_type = "images" if content_type in ALLOWED_IMAGE_TYPES else "videos"
|
||||
|
||||
@@ -6,7 +6,9 @@ import autogpt_libs.auth.middleware
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
|
||||
import backend.data.graph
|
||||
import backend.server.v2.store.db
|
||||
import backend.server.v2.store.image_gen
|
||||
import backend.server.v2.store.media
|
||||
import backend.server.v2.store.model
|
||||
|
||||
@@ -439,3 +441,63 @@ async def upload_submission_media(
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail=f"Failed to upload media file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/generate_image",
|
||||
tags=["store", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
async def generate_image(
|
||||
agent_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> fastapi.responses.Response:
|
||||
"""
|
||||
Generate an image for a store listing submission.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent to generate an image for
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
JSONResponse: JSON containing the URL of the generated image
|
||||
"""
|
||||
try:
|
||||
agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)
|
||||
|
||||
if not agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||
)
|
||||
# Use .jpeg here since we are generating JPEG images
|
||||
filename = f"agent_{agent_id}.jpeg"
|
||||
|
||||
existing_url = await backend.server.v2.store.media.check_media_exists(
|
||||
user_id, filename
|
||||
)
|
||||
if existing_url:
|
||||
logger.info(f"Using existing image for agent {agent_id}")
|
||||
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
|
||||
# Generate agent image as JPEG
|
||||
image = await backend.server.v2.store.image_gen.generate_agent_image(
|
||||
agent=agent
|
||||
)
|
||||
|
||||
# Create UploadFile with the correct filename and content_type
|
||||
image_file = fastapi.UploadFile(
|
||||
file=image,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
image_url = await backend.server.v2.store.media.upload_media(
|
||||
user_id=user_id, file=image_file, use_file_name=True
|
||||
)
|
||||
|
||||
return fastapi.responses.JSONResponse(content={"image_url": image_url})
|
||||
except Exception as e:
|
||||
logger.exception("Exception occurred whilst generating submission image")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail=f"Failed to generate image: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ interface PublishAgentInfoProps {
|
||||
) => void;
|
||||
onClose: () => void;
|
||||
initialData?: {
|
||||
agent_id: string;
|
||||
title: string;
|
||||
subheader: string;
|
||||
slug: string;
|
||||
@@ -36,6 +37,7 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
onClose,
|
||||
initialData,
|
||||
}) => {
|
||||
const [agentId, setAgentId] = React.useState<string | null>(null);
|
||||
const [images, setImages] = React.useState<string[]>(
|
||||
initialData?.additionalImages
|
||||
? [initialData.thumbnailSrc, ...initialData.additionalImages]
|
||||
@@ -59,10 +61,10 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
);
|
||||
const [slug, setSlug] = React.useState(initialData?.slug || "");
|
||||
const thumbnailsContainerRef = React.useRef<HTMLDivElement | null>(null);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (initialData) {
|
||||
setImages(initialData.additionalImages || []);
|
||||
setAgentId(initialData.agent_id);
|
||||
setImagesWithValidation(initialData.additionalImages || []);
|
||||
setSelectedImage(initialData.thumbnailSrc || null);
|
||||
setTitle(initialData.title);
|
||||
setSubheader(initialData.subheader);
|
||||
@@ -73,10 +75,18 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
}
|
||||
}, [initialData]);
|
||||
|
||||
const setImagesWithValidation = (newImages: string[]) => {
|
||||
// Remove duplicates
|
||||
const uniqueImages = Array.from(new Set(newImages));
|
||||
// Keep only first 5 images
|
||||
const limitedImages = uniqueImages.slice(0, 5);
|
||||
setImages(limitedImages);
|
||||
};
|
||||
|
||||
const handleRemoveImage = (indexToRemove: number) => {
|
||||
const newImages = [...images];
|
||||
newImages.splice(indexToRemove, 1);
|
||||
setImages(newImages);
|
||||
setImagesWithValidation(newImages);
|
||||
if (newImages[indexToRemove] === selectedImage) {
|
||||
setSelectedImage(newImages[0] || null);
|
||||
}
|
||||
@@ -88,6 +98,8 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
};
|
||||
|
||||
const handleAddImage = async () => {
|
||||
if (images.length >= 5) return;
|
||||
|
||||
const input = document.createElement("input");
|
||||
input.type = "file";
|
||||
input.accept = "image/*";
|
||||
@@ -115,11 +127,7 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
"$1",
|
||||
);
|
||||
|
||||
setImages((prev) => {
|
||||
const newImages = [...prev, imageUrl];
|
||||
console.log("Added image. Images now:", newImages);
|
||||
return newImages;
|
||||
});
|
||||
setImagesWithValidation([...images, imageUrl]);
|
||||
if (!selectedImage) {
|
||||
setSelectedImage(imageUrl);
|
||||
}
|
||||
@@ -128,6 +136,27 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
const [isGenerating, setIsGenerating] = React.useState(false);
|
||||
|
||||
const handleGenerateImage = async () => {
|
||||
if (isGenerating || images.length >= 5) return;
|
||||
|
||||
setIsGenerating(true);
|
||||
try {
|
||||
const api = new BackendAPI();
|
||||
if (!agentId) {
|
||||
throw new Error("Agent ID is required");
|
||||
}
|
||||
const { image_url } = await api.generateStoreSubmissionImage(agentId);
|
||||
console.log("image_url", image_url);
|
||||
setImagesWithValidation([...images, image_url]);
|
||||
} catch (error) {
|
||||
console.error("Failed to generate image:", error);
|
||||
} finally {
|
||||
setIsGenerating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = (e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
e.preventDefault();
|
||||
onSubmit(title, subheader, slug, description, images, youtubeLink, [
|
||||
@@ -284,19 +313,21 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
<Button
|
||||
onClick={handleAddImage}
|
||||
variant="ghost"
|
||||
className="flex h-[70px] w-[100px] flex-col items-center justify-center rounded-md bg-neutral-200 hover:bg-neutral-300 dark:bg-neutral-700 dark:hover:bg-neutral-600"
|
||||
>
|
||||
<IconPlus
|
||||
size="lg"
|
||||
className="text-neutral-600 dark:text-neutral-300"
|
||||
/>
|
||||
<span className="mt-1 font-['Geist'] text-xs font-normal text-neutral-600 dark:text-neutral-300">
|
||||
Add image
|
||||
</span>
|
||||
</Button>
|
||||
{images.length < 5 && (
|
||||
<Button
|
||||
onClick={handleAddImage}
|
||||
variant="ghost"
|
||||
className="flex h-[70px] w-[100px] flex-col items-center justify-center rounded-md bg-neutral-200 hover:bg-neutral-300 dark:bg-neutral-700 dark:hover:bg-neutral-600"
|
||||
>
|
||||
<IconPlus
|
||||
size="lg"
|
||||
className="text-neutral-600 dark:text-neutral-300"
|
||||
/>
|
||||
<span className="mt-1 font-['Geist'] text-xs font-normal text-neutral-600 dark:text-neutral-300">
|
||||
Add image
|
||||
</span>
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
@@ -313,9 +344,17 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
|
||||
<Button
|
||||
variant="default"
|
||||
size="sm"
|
||||
className="bg-neutral-800 text-white hover:bg-neutral-900 dark:bg-neutral-600 dark:hover:bg-neutral-500"
|
||||
className={`bg-neutral-800 text-white hover:bg-neutral-900 dark:bg-neutral-600 dark:hover:bg-neutral-500 ${
|
||||
images.length >= 5 ? "cursor-not-allowed opacity-50" : ""
|
||||
}`}
|
||||
onClick={handleGenerateImage}
|
||||
disabled={isGenerating || images.length >= 5}
|
||||
>
|
||||
Generate
|
||||
{isGenerating
|
||||
? "Generating..."
|
||||
: images.length >= 5
|
||||
? "Max images reached"
|
||||
: "Generate"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -45,6 +45,7 @@ export const PublishAgentPopout: React.FC<PublishAgentPopoutProps> = ({
|
||||
const [myAgents, setMyAgents] = React.useState<MyAgentsResponse | null>(null);
|
||||
const [selectedAgent, setSelectedAgent] = React.useState<string | null>(null);
|
||||
const [initialData, setInitialData] = React.useState<{
|
||||
agent_id: string;
|
||||
title: string;
|
||||
subheader: string;
|
||||
slug: string;
|
||||
@@ -119,6 +120,7 @@ export const PublishAgentPopout: React.FC<PublishAgentPopoutProps> = ({
|
||||
const name = selectedAgentData?.agent_name || "";
|
||||
const description = selectedAgentData?.description || "";
|
||||
setInitialData({
|
||||
agent_id: agentId,
|
||||
title: name,
|
||||
subheader: "",
|
||||
description: description,
|
||||
|
||||
@@ -299,6 +299,15 @@ export default class BackendAPI {
|
||||
return this._request("POST", "/store/submissions", submission);
|
||||
}
|
||||
|
||||
generateStoreSubmissionImage(
|
||||
agent_id: string,
|
||||
): Promise<{ image_url: string }> {
|
||||
return this._request(
|
||||
"POST",
|
||||
"/store/submissions/generate_image?agent_id=" + agent_id,
|
||||
);
|
||||
}
|
||||
c;
|
||||
deleteStoreSubmission(submission_id: string): Promise<boolean> {
|
||||
return this._request("DELETE", `/store/submissions/${submission_id}`);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user