feat(ui, nodes): metadata wip

This commit is contained in:
psychedelicious
2023-04-14 12:52:07 +10:00
parent 49612d69d0
commit 65f2a7ea31
43 changed files with 600 additions and 305 deletions

View File

@@ -1,6 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
import uuid
from fastapi import Path, Query, Request, UploadFile
@@ -8,6 +10,7 @@ from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@@ -40,32 +43,47 @@ async def get_thumbnail(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
404: {"description": "Session not found"},
},
status_code=201
)
async def upload_image(file: UploadFile, request: Request):
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
if not file.content_type.startswith("image"):
return Response(status_code=415)
contents = await file.read()
try:
im = Image.open(io.BytesIO(contents))
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
return Response(status_code=415)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
return Response(
status_code=201,
headers={
"Location": request.url_for(
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
created=int(os.path.getctime(image_path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
},
)
return res
@images_router.get(
"/",

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class CvInvocationConfig(BaseModel):
@@ -56,7 +56,9 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, image_inpainted, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@@ -9,9 +9,9 @@ from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..models.exceptions import CanceledException
@@ -76,6 +76,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model_name = model["model_name"]
outputs = Txt2Img(model).generate(
prompt=self.prompt,
@@ -95,9 +96,22 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[self.id]
invocation = graph_execution_state.execution_graph.get_node(self.id)
metadata = {
"session": context.graph_execution_state_id,
"source_id": source_id,
"invocation": invocation.dict()
}
context.services.images.save(image_type, image_name, generate_output.image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image
)
@@ -144,6 +158,7 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
outputs = Img2Img(model).generate(
prompt=self.prompt,
@@ -168,9 +183,11 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
)
class InpaintInvocation(ImageToImageInvocation):
@@ -218,7 +235,8 @@ class InpaintInvocation(ImageToImageInvocation):
)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
outputs = Inpaint(model).generate(
prompt=self.prompt,
@@ -243,7 +261,9 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
)

View File

@@ -9,7 +9,12 @@ from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
@@ -22,51 +27,70 @@ class PILInvocationConfig(BaseModel):
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
#fmt: off
# fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
#fmt: on
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
"required": [
"type",
"image",
"width",
"height",
]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
image_field = ImageField(image_name=image_name, image_type=image_type)
return ImageOutput(image=image_field, width=image.width, height=image.height)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
#fmt: off
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
#fmt: on
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
"required": [
"type",
"mask",
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
#fmt: off
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
#fmt: on
# # TODO: this isn't really necessary anymore
# class LoadImageInvocation(BaseInvocation):
# """Load an image from a filename and provide it as output."""
# #fmt: off
# type: Literal["load_image"] = "load_image"
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
)
# # Inputs
# image_type: ImageType = Field(description="The type of the image")
# image_name: str = Field(description="The name of the image")
# #fmt: on
# def invoke(self, context: InvocationContext) -> ImageOutput:
# return ImageOutput(
# image_type=self.image_type,
# image_name=self.image_name,
# image=result_image
# )
class ShowImageInvocation(BaseInvocation):
@@ -86,16 +110,17 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
#fmt: off
# fmt: off
type: Literal["crop"] = "crop"
# Inputs
@@ -104,7 +129,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -120,15 +145,16 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, image_crop, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=image_crop
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
#fmt: off
# fmt: off
type: Literal["paste"] = "paste"
# Inputs
@@ -137,7 +163,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@@ -170,21 +196,22 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, new_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=new_image
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
"""Extracts the alpha channel of an image as a mask."""
#fmt: off
# fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@@ -199,22 +226,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask)
context.services.images.save(image_type, image_name, image_mask, self.dict())
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
#fmt: off
# fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -231,22 +258,23 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, blur_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -262,23 +290,24 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, lerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -298,7 +327,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
)

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from ...backend.model_management.model_manager import ModelManager
@@ -19,7 +19,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageField, ImageOutput, build_image_output
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
@@ -405,7 +405,9 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
)

View File

@@ -6,7 +6,7 @@ from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@@ -44,7 +44,9 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@@ -8,7 +8,7 @@ from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class UpscaleInvocation(BaseInvocation):
@@ -49,7 +49,9 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@@ -1,11 +1,14 @@
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if `model_name` is not a valid model, else returns the selected model."""
if model_manager.valid_model(model_name):
return model_manager.get_model(model_name)
model = model_manager.get_model(model_name)
else:
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
return model_manager.get_model()
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
return model

View File

@@ -1,11 +1,26 @@
from typing import Optional
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
class ImageMetadata(BaseModel):
"""An image's metadata"""
timestamp: float = Field(description="The creation timestamp of the image")
class InvokeAIMetadata(BaseModel):
"""An image's InvokeAI-specific metadata"""
session: Optional[str] = Field(description="The session that generated this image")
source_id: Optional[str] = Field(
description="The source id of the invocation that generated this image"
)
# TODO: figure out metadata
invocation: Optional[Dict[str, Any]] = Field(
default={}, description="The prepared invocation that generated this image"
)
class ImageMetadata(BaseModel):
"""An image's general metadata"""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# TODO: figure out metadata
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
invokeai: Optional[InvokeAIMetadata] = Field(
default={}, description="The image's InvokeAI-specific metadata"
)

View File

@@ -25,7 +25,8 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_id: str,
invocation_dict: dict,
source_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@@ -35,7 +36,8 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
progress_image=progress_image,
step=step,
total_steps=total_steps,
@@ -43,40 +45,43 @@ class EventServiceBase:
)
def emit_invocation_complete(
self, graph_execution_state_id: str, invocation_id: str, result: Dict
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_id: str, error: str
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_id: str
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
),
)

View File

@@ -2,16 +2,17 @@
import datetime
import os
import json
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List, Union
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel
from pydantic import BaseModel, Json
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
@@ -42,7 +43,7 @@ class ImageStorageBase(ABC):
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
pass
@abstractmethod
@@ -100,6 +101,8 @@ class DiskImageStorage(ImageStorageBase):
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
page_of_images.append(
ImageResponse(
image_type=image_type.value,
@@ -109,9 +112,10 @@ class DiskImageStorage(ImageStorageBase):
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
timestamp=os.path.getctime(path),
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
)
@@ -150,10 +154,11 @@ class DiskImageStorage(ImageStorageBase):
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
print(metadata)
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
image, "", image_subpath, metadata
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
@@ -162,6 +167,7 @@ class DiskImageStorage(ImageStorageBase):
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
return image_path
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)

View File

@@ -43,10 +43,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node to provide to clients (the prepared node is not as useful)
source_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invocation_dict=invocation.dict(),
source_id=source_id
)
# Invoke
@@ -75,7 +79,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invocation_dict=invocation.dict(),
source_id=source_id,
result=outputs.dict(),
)
@@ -99,7 +104,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invocation_dict=invocation.dict(),
source_id=source_id,
error=error,
)

View File

@@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Union, get_args, get_origin
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults

View File

@@ -1,3 +1,4 @@
from re import S
import torch
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
@@ -20,12 +21,18 @@ def fast_latents_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[id]
invocation = graph_execution_state.execution_graph.get_node(id)
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{"width": width, "height": height, "dataURL": dataURL},
step,
steps,
graph_execution_state_id=context.graph_execution_state_id,
invocation_dict=invocation.dict(),
source_id=source_id,
progress_image={"width": width, "height": height, "dataURL": dataURL},
step=step,
total_steps=steps,
)

View File

@@ -41,7 +41,7 @@ class PngWriter:
info = PngImagePlugin.PngInfo()
info.add_text("Dream", dream_prompt)
if metadata:
info.add_text("sd-metadata", json.dumps(metadata))
info.add_text("invokeai", json.dumps(metadata))
image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
return path

View File

@@ -15,6 +15,7 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
/**
* TODO:

View File

@@ -6,19 +6,31 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import resultsReducer from 'features/gallery/store/resultsSlice';
import galleryReducer, {
GalleryState,
} from 'features/gallery/store/gallerySlice';
import resultsReducer, {
resultsAdapter,
ResultsState,
} from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import lightboxReducer, {
LightboxState,
} from 'features/lightbox/store/lightboxSlice';
import generationReducer, {
GenerationState,
} from 'features/parameters/store/generationSlice';
import postprocessingReducer, {
PostprocessingState,
} from 'features/parameters/store/postprocessingSlice';
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { CanvasState } from 'features/canvas/store/canvasTypes';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@@ -34,13 +46,21 @@ import { socketMiddleware } from 'services/events/middleware';
 * The necessary nested persistors with blacklists are configured below.
*/
const canvasBlacklist = [
/**
* Canvas slice persist blacklist
*/
const canvasBlacklist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
].map((blacklistItem) => `canvas.${blacklistItem}`);
];
const systemBlacklist = [
canvasBlacklist.map((blacklistItem) => `canvas.${blacklistItem}`);
/**
* System slice persist blacklist
*/
const systemBlacklist: (keyof SystemState)[] = [
'currentIteration',
'currentStatus',
'currentStep',
@@ -53,40 +73,101 @@ const systemBlacklist = [
'totalIterations',
'totalSteps',
'openModel',
'cancelOptions.cancelAfter',
'isCancelScheduled',
'sessionId',
].map((blacklistItem) => `system.${blacklistItem}`);
'progressImage',
];
const galleryBlacklist = [
systemBlacklist.map((blacklistItem) => `system.${blacklistItem}`);
/**
* Gallery slice persist blacklist
*/
const galleryBlacklist: (keyof GalleryState)[] = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
].map((blacklistItem) => `gallery.${blacklistItem}`);
];
const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}`
galleryBlacklist.map((blacklistItem) => `gallery.${blacklistItem}`);
/**
* Lightbox slice persist blacklist
*/
const lightboxBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
lightboxBlacklist.map((blacklistItem) => `lightbox.${blacklistItem}`);
/**
* Nodes slice persist blacklist
*/
const nodesBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
nodesBlacklist.map((blacklistItem) => `nodes.${blacklistItem}`);
/**
* Generation slice persist blacklist
*/
const generationBlacklist: (keyof GenerationState)[] = [];
generationBlacklist.map((blacklistItem) => `generation.${blacklistItem}`);
/**
* Postprocessing slice persist blacklist
*/
const postprocessingBlacklist: (keyof PostprocessingState)[] = [];
postprocessingBlacklist.map(
(blacklistItem) => `postprocessing.${blacklistItem}`
);
const nodesBlacklist = ['schema', 'invocations'].map(
(blacklistItem) => `nodes.${blacklistItem}`
);
/**
* Results slice persist blacklist
*
* Currently blacklisting results slice entirely, see persist config below
*/
const resultsBlacklist: (keyof ResultsState)[] = [];
resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
/**
* Uploads slice persist blacklist
*
* Currently blacklisting uploads slice entirely, see persist config below
*/
const uploadsBlacklist: (keyof NodesState)[] = [];
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
/**
* Models slice persist blacklist
*/
const modelsBlacklist: (keyof NodesState)[] = [];
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
/**
* UI slice persist blacklist
*/
const uiBlacklist: (keyof NodesState)[] = [];
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);
const rootReducer = combineReducers({
generation: generationReducer,
postprocessing: postprocessingReducer,
gallery: galleryReducer,
system: systemReducer,
canvas: canvasReducer,
ui: uiReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
results: resultsReducer,
uploads: uploadsReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
results: resultsReducer,
system: systemReducer,
ui: uiReducer,
uploads: uploadsReducer,
});
const rootPersistConfig = getPersistConfig({
@@ -95,13 +176,18 @@ const rootPersistConfig = getPersistConfig({
rootReducer,
blacklist: [
...canvasBlacklist,
...systemBlacklist,
...galleryBlacklist,
...generationBlacklist,
...lightboxBlacklist,
...modelsBlacklist,
...nodesBlacklist,
...postprocessingBlacklist,
// ...resultsBlacklist,
'results',
...systemBlacklist,
...uiBlacklist,
// ...uploadsBlacklist,
'uploads',
// 'nodes',
],
debounce: 300,
});

View File

@@ -66,5 +66,7 @@ export const buildGraph = (state: RootState): BuildGraphOutput => {
nodeIdsToSubscribe.push(Object.keys(node)[0]);
}
console.log('buildGraph: ', graph);
return { graph, nodeIdsToSubscribe };
};

View File

@@ -2,6 +2,14 @@ import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.system.shouldTransformUrls
@@ -11,11 +19,9 @@ export const useGetUrl = () => {
shouldTransformUrls,
getUrl: (url: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
console.log('transformed');
return [OpenAPI.BASE, url].join('/');
}
console.log('didnt transform');
return url;
},
};

View File

@@ -5,6 +5,7 @@ import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { ReactEventHandler } from 'react';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { selectedImageSelector } from '../store/gallerySelectors';

View File

@@ -46,6 +46,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
@@ -167,7 +168,17 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image.metadata, null, 2);
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
return (
<Flex
@@ -193,6 +204,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
@@ -411,37 +453,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width="100%" pt={10}>

View File

@@ -295,7 +295,7 @@ export const gallerySlice = createSlice({
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const location = action.payload;
const { location } = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});

View File

@@ -8,8 +8,14 @@ import {
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import { deserializeImageField } from 'services/util/deserializeImageField';
import {
buildImageUrls,
deserializeImageField,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { getUrlAlt } from 'common/util/getUrl';
import { ImageMetadata } from 'services/api';
// import { deserializeImageField } from 'services/util/deserializeImageField';
// use `createEntityAdapter` to create a slice for results images
@@ -21,7 +27,7 @@ export const resultsAdapter = createEntityAdapter<Image>({
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
// This type is intersected with the Entity type to create the shape of the state
@@ -34,22 +40,31 @@ type AdditionalResultsState = {
nextPage: number; // the next page to request
};
const resultsSlice = createSlice({
name: 'results',
initialState: resultsAdapter.getInitialState<AdditionalResultsState>({
// export type ResultsState = ReturnType<
// typeof resultsAdapter.getInitialState<AdditionalResultsState>
// >;
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
}),
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
// the adapter provides some helper reducers; see the docs for all of them
// can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.addOne,
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
@@ -86,10 +101,34 @@ const resultsSlice = createSlice({
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, invocation, graph_execution_state_id, source_id } = data;
if (isImageOutput(data.result)) {
const resultImage = deserializeImageField(data.result.image);
resultsAdapter.addOne(state, resultImage);
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
invokeai: {
session: graph_execution_state_id,
source_id,
invocation,
},
},
};
// const resultImage = deserializeImageField(result.image, invocation);
resultsAdapter.addOne(state, image);
}
});
},

View File

@@ -12,7 +12,7 @@ import { deserializeImageResponse } from 'services/util/deserializeImageResponse
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
type AdditionalUploadsState = {
@@ -22,6 +22,10 @@ type AdditionalUploadsState = {
nextPage: number;
};
export type UploadssState = ReturnType<
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
>;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
@@ -61,12 +65,17 @@ const uploadsSlice = createSlice({
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const location = action.payload;
const { location, response } = action.payload;
const { image_name, image_url, image_type, metadata, thumbnail_url } =
response;
const uploadedImage = deserializeImageField({
image_name: location.split('/').pop() || '',
image_type: 'uploads',
});
const uploadedImage: Image = {
name: image_name,
url: image_url,
thumbnail: thumbnail_url,
type: 'uploads',
metadata,
};
uploadsAdapter.addOne(state, uploadedImage);
});

View File

@@ -413,7 +413,13 @@ export const systemSlice = createSlice({
* Generator Progress
*/
builder.addCase(generatorProgress, (state, action) => {
const { step, total_steps, progress_image } = action.payload.data;
const {
step,
total_steps,
progress_image,
invocation,
graph_execution_state_id,
} = action.payload.data;
state.currentStatusHasSteps = true;
state.currentStep = step + 1; // TODO: step starts at -1, think this is a bug

View File

@@ -34,6 +34,7 @@ export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
export type { InvokeAIMetadata } from './models/InvokeAIMetadata';
export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField';
@@ -41,7 +42,6 @@ export type { LatentsOutput } from './models/LatentsOutput';
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
export type { LerpInvocation } from './models/LerpInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { ModelsList } from './models/ModelsList';
@@ -91,6 +91,7 @@ export { $InpaintInvocation } from './schemas/$InpaintInvocation';
export { $IntCollectionOutput } from './schemas/$IntCollectionOutput';
export { $IntOutput } from './schemas/$IntOutput';
export { $InverseLerpInvocation } from './schemas/$InverseLerpInvocation';
export { $InvokeAIMetadata } from './schemas/$InvokeAIMetadata';
export { $IterateInvocation } from './schemas/$IterateInvocation';
export { $IterateInvocationOutput } from './schemas/$IterateInvocationOutput';
export { $LatentsField } from './schemas/$LatentsField';
@@ -98,7 +99,6 @@ export { $LatentsOutput } from './schemas/$LatentsOutput';
export { $LatentsToImageInvocation } from './schemas/$LatentsToImageInvocation';
export { $LatentsToLatentsInvocation } from './schemas/$LatentsToLatentsInvocation';
export { $LerpInvocation } from './schemas/$LerpInvocation';
export { $LoadImageInvocation } from './schemas/$LoadImageInvocation';
export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
export { $MaskOutput } from './schemas/$MaskOutput';
export { $ModelsList } from './schemas/$ModelsList';

View File

@@ -17,7 +17,6 @@ import type { IterateInvocation } from './IterateInvocation';
import type { LatentsToImageInvocation } from './LatentsToImageInvocation';
import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
import type { LerpInvocation } from './LerpInvocation';
import type { LoadImageInvocation } from './LoadImageInvocation';
import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
import type { MultiplyInvocation } from './MultiplyInvocation';
import type { NoiseInvocation } from './NoiseInvocation';
@@ -39,7 +38,7 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | CvInpaintInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
nodes?: Record<string, (ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | CvInpaintInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/

View File

@@ -2,14 +2,16 @@
/* tslint:disable */
/* eslint-disable */
import type { InvokeAIMetadata } from './InvokeAIMetadata';
/**
* An image's metadata
* An image's general metadata
*/
export type ImageMetadata = {
/**
* The creation timestamp of the image
*/
timestamp: number;
created: number;
/**
* The width of the image in pixels
*/
@@ -19,8 +21,8 @@ export type ImageMetadata = {
*/
height: number;
/**
* The image's SD-specific metadata
* The image's InvokeAI-specific metadata
*/
sd_metadata?: any;
invokeai?: InvokeAIMetadata;
};

View File

@@ -13,5 +13,13 @@ export type ImageOutput = {
* The output image
*/
image: ImageField;
/**
* The width of the image in pixels
*/
width: number;
/**
* The height of the image in pixels
*/
height: number;
};

View File

@@ -0,0 +1,22 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An image's InvokeAI-specific metadata
*/
export type InvokeAIMetadata = {
/**
* The session that generated this image
*/
session?: string;
/**
* The source id of the invocation that generated this image
*/
source_id?: string;
/**
* The prepared invocation that generated this image
*/
invocation?: any;
};

View File

@@ -1,25 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageType } from './ImageType';
/**
* Load an image from a filename and provide it as output.
*/
export type LoadImageInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'load_image';
/**
* The type of the image
*/
image_type: ImageType;
/**
* The name of the image
*/
image_name: string;
};

View File

@@ -12,8 +12,6 @@ export const $Graph = {
contains: {
type: 'one-of',
contains: [{
type: 'LoadImageInvocation',
}, {
type: 'ShowImageInvocation',
}, {
type: 'CropImageInvocation',

View File

@@ -2,9 +2,9 @@
/* tslint:disable */
/* eslint-disable */
export const $ImageMetadata = {
description: `An image's metadata`,
description: `An image's general metadata`,
properties: {
timestamp: {
created: {
type: 'number',
description: `The creation timestamp of the image`,
isRequired: true,
@@ -19,10 +19,12 @@ export const $ImageMetadata = {
description: `The height of the image in pixels`,
isRequired: true,
},
sd_metadata: {
description: `The image's SD-specific metadata`,
properties: {
},
invokeai: {
type: 'all-of',
description: `The image's InvokeAI-specific metadata`,
contains: [{
type: 'InvokeAIMetadata',
}],
},
},
} as const;

View File

@@ -16,5 +16,15 @@ export const $ImageOutput = {
}],
isRequired: true,
},
width: {
type: 'number',
description: `The width of the image in pixels`,
isRequired: true,
},
height: {
type: 'number',
description: `The height of the image in pixels`,
isRequired: true,
},
},
} as const;

View File

@@ -0,0 +1,21 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $InvokeAIMetadata = {
description: `An image's InvokeAI-specific metadata`,
properties: {
session: {
type: 'string',
description: `The session that generated this image`,
},
source_id: {
type: 'string',
description: `The source id of the invocation that generated this image`,
},
invocation: {
description: `The prepared invocation that generated this image`,
properties: {
},
},
},
} as const;

View File

@@ -1,29 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $LoadImageInvocation = {
description: `Load an image from a filename and provide it as output.`,
properties: {
id: {
type: 'string',
description: `The id of this node. Must be unique among all nodes.`,
isRequired: true,
},
type: {
type: 'Enum',
},
image_type: {
type: 'all-of',
description: `The type of the image`,
contains: [{
type: 'ImageType',
}],
isRequired: true,
},
image_name: {
type: 'string',
description: `The name of the image`,
isRequired: true,
},
},
} as const;

View File

@@ -2,6 +2,7 @@
/* tslint:disable */
/* eslint-disable */
import type { Body_upload_image } from '../models/Body_upload_image';
import type { ImageResponse } from '../models/ImageResponse';
import type { ImageType } from '../models/ImageType';
import type { PaginatedResults_ImageResponse_ } from '../models/PaginatedResults_ImageResponse_';
@@ -77,14 +78,14 @@ export class ImagesService {
/**
* Upload Image
* @returns any Successful Response
* @returns ImageResponse The image was uploaded successfully
* @throws ApiError
*/
public static uploadImage({
formData,
}: {
formData: Body_upload_image,
}): CancelablePromise<any> {
}): CancelablePromise<ImageResponse> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/images/uploads/',

View File

@@ -18,7 +18,6 @@ import type { IterateInvocation } from '../models/IterateInvocation';
import type { LatentsToImageInvocation } from '../models/LatentsToImageInvocation';
import type { LatentsToLatentsInvocation } from '../models/LatentsToLatentsInvocation';
import type { LerpInvocation } from '../models/LerpInvocation';
import type { LoadImageInvocation } from '../models/LoadImageInvocation';
import type { MaskFromAlphaInvocation } from '../models/MaskFromAlphaInvocation';
import type { MultiplyInvocation } from '../models/MultiplyInvocation';
import type { NoiseInvocation } from '../models/NoiseInvocation';
@@ -141,7 +140,7 @@ export class SessionsService {
* The id of the session
*/
sessionId: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | CvInpaintInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
requestBody: (ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | CvInpaintInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<string> {
return __request(OpenAPI, {
method: 'POST',
@@ -178,7 +177,7 @@ export class SessionsService {
* The path to the node in the graph
*/
nodePath: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | CvInpaintInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
requestBody: (ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | CvInpaintInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'PUT',

View File

@@ -177,25 +177,25 @@ export const socketMiddleware = () => {
// Set up listeners for the present subscription
socket.on('invocation_started', (data) => {
if (shouldHandleEvent(data.source_id)) {
if (shouldHandleEvent(data.invocation.id)) {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
}
});
socket.on('generator_progress', (data) => {
if (shouldHandleEvent(data.source_id)) {
if (shouldHandleEvent(data.invocation.id)) {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_error', (data) => {
if (shouldHandleEvent(data.source_id)) {
if (shouldHandleEvent(data.invocation.id)) {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_complete', (data) => {
if (shouldHandleEvent(data.source_id)) {
if (shouldHandleEvent(data.invocation.id)) {
const sessionId = data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;

View File

@@ -13,6 +13,12 @@ export type AnyInvocationType = NonNullable<
NonNullable<Graph['nodes']>[string]['type']
>;
export type AnyInvocation = {
id: string;
type: AnyInvocationType | string;
[key: string]: any;
};
export type AnyResult = GraphExecutionState['results'][string];
/**
@@ -22,7 +28,7 @@ export type AnyResult = GraphExecutionState['results'][string];
*/
export type GeneratorProgressEvent = {
graph_execution_state_id: string;
invocation: AnyInvocationType;
invocation: AnyInvocation;
source_id: string;
progress_image?: ProgressImage;
step: number;
@@ -38,6 +44,7 @@ export type GeneratorProgressEvent = {
*/
export type InvocationCompleteEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
source_id: string;
result: AnyResult;
};
@@ -49,7 +56,7 @@ export type InvocationCompleteEvent = {
*/
export type InvocationErrorEvent = {
graph_execution_state_id: string;
invocation: AnyInvocationType;
invocation: AnyInvocation;
source_id: string;
error: string;
};
@@ -61,8 +68,8 @@ export type InvocationErrorEvent = {
*/
export type InvocationStartedEvent = {
graph_execution_state_id: string;
invocation: AnyInvocationType;
source_id: string;
invocation: AnyInvocation;
};
/**

View File

@@ -26,7 +26,7 @@ export const imageUploaded = createAppAsyncThunk(
async (arg: ImageUploadedArg, _thunkApi) => {
const response = await ImagesService.uploadImage(arg);
const { location } = getHeaders(response);
return location;
return { response, location };
}
);

View File

@@ -1,5 +1,6 @@
import { Image } from 'app/invokeai';
import { ImageField, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
export const buildImageUrls = (
imageType: ImageType,
@@ -34,23 +35,21 @@ export const extractTimestampFromImageName = (imageName: string) => {
* TODO: do some more janky stuff here to get image dimensions instead of defaulting to 512x512?
* TODO: better yet, improve the nodes server (wip)
*/
export const deserializeImageField = (image: ImageField): Image => {
export const deserializeImageField = (
image: ImageField,
invocation: AnyInvocation
): Image => {
const name = image.image_name;
const type = image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
return {
name,
type,
url,
thumbnail,
metadata: {
timestamp,
height: 512,
width: 512,
},
timestamp,
metadata: invocation,
};
};

View File

@@ -10,6 +10,8 @@ export const deserializeImageResponse = (
const { image_name, image_type, image_url, metadata, thumbnail_url } =
imageResponse;
const { width, height, timestamp, invokeai } = metadata;
// TODO: parse metadata - just leaving it as-is for now
return {
@@ -17,6 +19,7 @@ export const deserializeImageResponse = (
type: image_type,
url: image_url,
thumbnail: thumbnail_url,
metadata,
timestamp,
metadata: invokeai,
};
};