feat(api): support resizing image on upload

This commit is contained in:
psychedelicious
2025-05-28 20:08:45 +10:00
parent 91db136cd1
commit d5b9c3ee5a

View File

@@ -1,5 +1,7 @@
import io
import json
import traceback
from time import time
from typing import Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
@@ -19,6 +21,8 @@ from invokeai.app.services.image_records.image_records_common import (
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -27,6 +31,11 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
IMAGE_MAX_AGE = 31536000
class Dimensions(BaseModel):
width: int = Field(...)
height: int = Field(...)
@images_router.post(
"/upload",
operation_id="upload_image",
@@ -46,6 +55,11 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
resize_to: Optional[str] = Body(
default=None,
description="Dimensions to resize the image to, must be stringified JSON array of 2 integers",
example="[1024,1024]",
),
metadata: Optional[str] = Body(
default=None,
description="The metadata to associate with the image, must be a stringified JSON dict",
@@ -59,13 +73,33 @@ async def upload_image(
contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
if crop_visible:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
if crop_visible:
try:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
except Exception:
raise HTTPException(status_code=500, detail="Failed to crop image")
if resize_to:
try:
dims = json.loads(resize_to)
resize_dims = Dimensions(**dims)
except Exception:
raise HTTPException(status_code=400, detail="Invalid resize_to format")
try:
start_time = time()
np_image = pil_to_np(pil_image)
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
pil_image = np_to_pil(np_image)
print("resize took seconds: ", time() - start_time)
except Exception:
raise HTTPException(status_code=500, detail="Failed to resize image")
extracted_metadata = extract_metadata_from_image(
pil_image=pil_image,
invokeai_metadata_override=metadata,