feat(nodes): wip address metadata feedback

This commit is contained in:
psychedelicious
2023-04-17 21:56:15 +10:00
parent 64f044a984
commit fff55bd991
10 changed files with 174 additions and 65 deletions

View File

@@ -5,12 +5,12 @@ import json
import os
import uuid
from fastapi import Path, Query, Request, UploadFile
from fastapi import HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.metadata import ImageMetadata, InvokeAIMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@@ -18,24 +18,48 @@ from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
) -> FileResponse | Response:
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
# Send only the filename (no relative path shenanigans)
basename = os.path.basename(image_name) # only send the filename
filename = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=basename
)
try:
os.stat(filename)
except FileNotFoundError:
raise HTTPException(status_code=404)
return FileResponse(filename)
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
@images_router.get(
"/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
) -> FileResponse | Response:
"""Gets a thumbnail"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
# Send only the filename (no relative path shenanigans)
basename = os.path.basename(image_name)
filename = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=basename, is_thumbnail=True
)
try:
os.stat(filename)
except FileNotFoundError:
raise HTTPException(status_code=404)
return FileResponse(filename)
@@ -43,27 +67,37 @@ async def get_thumbnail(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
404: {"description": "Session not found"},
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201
status_code=201,
)
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
async def upload_image(
file: UploadFile, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
return Response(status_code=415)
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
return Response(status_code=415)
raise HTTPException(status_code=415, detail="Image reading failed")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
image_path = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
)
# TODO: handle old `sd-metadata` style metadata
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
invokeai_metadata = img.info.get("invokeai", None)
if invokeai_metadata is not None:
invokeai_metadata = InvokeAIMetadata(**json.loads(invokeai_metadata))
    # TODO: should creation of this object happen elsewhere?
res = ImageResponse(
@@ -75,29 +109,31 @@ async def upload_image(file: UploadFile, request: Request, response: Response) -
created=int(os.path.getctime(image_path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
mode=img.mode,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(
image_type, page, per_page
)
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.metadata import InvokeAIMetadata
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@@ -95,16 +96,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[self.id]
invocation = graph_execution_state.execution_graph.get_node(self.id)
metadata = {
"session": context.graph_execution_state_id,
"source_id": source_id,
"invocation": invocation.dict()
}
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, generate_output.image, metadata)
return build_image_output(
@@ -181,7 +177,13 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
@@ -258,7 +260,13 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,

View File

@@ -6,6 +6,8 @@ import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.metadata import InvokeAIMetadata
from ..models.image import ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
@@ -151,7 +153,13 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image_crop
)
@@ -202,7 +210,13 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=new_image
)
@@ -232,7 +246,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
@@ -264,7 +283,12 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
)
@@ -296,7 +320,12 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
)
@@ -333,7 +362,12 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
)

View File

@@ -7,6 +7,7 @@ import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.metadata import InvokeAIMetadata
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from ...backend.model_management.model_manager import ModelManager
@@ -355,7 +356,12 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image, self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,

View File

@@ -4,6 +4,7 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import InvokeAIMetadata
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
@@ -44,7 +45,12 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,

View File

@@ -6,6 +6,7 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import InvokeAIMetadata
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
@@ -49,7 +50,12 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,

View File

@@ -16,10 +16,11 @@ class ImageField(BaseModel):
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
mode: str = Field(description="The image mode (ie pixel format)")
info: dict = Field(description="The image file's metadata")
created: Optional[int] = Field(default=None, description="The creation time of the image")
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
mode: Optional[str] = Field(default=None, description="The image mode (ie pixel format)")
info: Optional[dict] = Field(default=None, description="The image file's metadata")
class Config:
schema_extra = {

View File

@@ -5,18 +5,19 @@ from pydantic import BaseModel, Field
class InvokeAIMetadata(BaseModel):
"""An image's InvokeAI-specific metadata"""
session_id: Optional[str] = Field(description="The session that generated this image")
invocation: Optional[Dict[str, Any]] = Field(
session_id: str = Field(description="The session that generated this image")
invocation: dict = Field(
default={}, description="The prepared invocation that generated this image"
)
class ImageMetadata(BaseModel):
"""An image's general metadata"""
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
mode: str = Field(description="The color mode of the image")
invokeai: Optional[InvokeAIMetadata] = Field(
default={}, description="The image's InvokeAI-specific metadata"
description="The image's InvokeAI-specific metadata"
)

View File

@@ -45,9 +45,10 @@ class EventServiceBase:
)
def emit_invocation_complete(
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
self, graph_execution_state_id: str, result: dict, invocation_dict: dict, source_id: str,
) -> None:
"""Emitted when an invocation has completed"""
print(result)
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
@@ -59,7 +60,7 @@ class EventServiceBase:
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
self, graph_execution_state_id: str, invocation_dict: dict, source_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
@@ -73,7 +74,7 @@ class EventServiceBase:
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
self, graph_execution_state_id: str, invocation_dict: dict, source_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(

View File

@@ -11,6 +11,7 @@ from typing import Any, Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from PIL import PngImagePlugin
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata, InvokeAIMetadata
@@ -41,7 +42,7 @@ class ImageStorageBase(ABC):
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: InvokeAIMetadata | Dict[str, Any] | None = None) -> str:
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: InvokeAIMetadata | None = None) -> str:
pass
@abstractmethod
@@ -99,9 +100,12 @@ class DiskImageStorage(ImageStorageBase):
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
# TODO: handle old `sd-metadata` format
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
invokeai_metadata = img.info.get("invokeai", None)
if invokeai_metadata is not None:
invokeai_metadata = InvokeAIMetadata(**json.loads(invokeai_metadata))
page_of_images.append(
ImageResponse(
@@ -115,6 +119,7 @@ class DiskImageStorage(ImageStorageBase):
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
mode=img.mode,
invokeai=invokeai_metadata
),
)
@@ -154,16 +159,21 @@ class DiskImageStorage(ImageStorageBase):
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
print(metadata)
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, metadata
) # TODO: just pass full path to png writer
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: InvokeAIMetadata | None = None) -> str:
image_subpath = os.path.join(self.__output_folder, image_type)
image_path = os.path.join(image_subpath, image_name)
info = PngImagePlugin.PngInfo()
if metadata:
info.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=info)
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
path=os.path.join(image_subpath, "thumbnails"),
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)