feat(nodes): add metadata module + tests, thumbnails

- `MetadataModule` is stateless and is needed in places where the `InvocationContext` is not available, so it has not been made a `service`
- Handles loading/parsing/building metadata, and creating png info objects
- added tests for MetadataModule
- Lifted thumbnail stuff to util
This commit is contained in:
psychedelicious
2023-04-19 23:49:40 +10:00
parent 2b53ce50e0
commit 162bcda49e
23 changed files with 638 additions and 223 deletions

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.modules.metadata import ImageMetadata
class ImageResponse(BaseModel):

View File

@@ -10,7 +10,7 @@ from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.metadata import ImageMetadata, InvokeAIMetadata
from invokeai.app.modules.metadata import ImageMetadata, InvokeAIMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType

View File

@@ -6,7 +6,12 @@ import numpy as np
import numpy.random
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
BaseInvocationOutput,
)
class IntCollectionOutput(BaseInvocationOutput):
@@ -41,12 +46,14 @@ class RandomRangeInvocation(BaseInvocation):
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: Optional[int] = Field(
ge=0,
le=np.iinfo(np.int32).max,
description="The seed for the RNG",
description="The seed for the RNG, provide None or -1 for random",
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
)

View File

@@ -8,7 +8,7 @@ from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import InvokeAIMetadata
from invokeai.app.modules.metadata import MetadataModule
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
@@ -58,9 +58,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)

View File

@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.metadata import InvokeAIMetadata
from invokeai.app.modules.metadata import MetadataModule
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@@ -58,7 +58,10 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState,
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
@@ -72,7 +75,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Txt2Img(model).generate(
@@ -93,13 +98,14 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, generate_output.image, metadata)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
@@ -123,8 +129,11 @@ class ImageToImageInvocation(TextToImageInvocation):
)
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
@@ -146,18 +155,20 @@ class ImageToImageInvocation(TextToImageInvocation):
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
@@ -172,10 +183,9 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
@@ -185,6 +195,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@@ -200,7 +211,10 @@ class InpaintInvocation(ImageToImageInvocation):
)
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
@@ -227,18 +241,20 @@ class InpaintInvocation(ImageToImageInvocation):
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
@@ -254,9 +270,8 @@ class InpaintInvocation(ImageToImageInvocation):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, result_image, metadata)

View File

@@ -6,7 +6,7 @@ import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.metadata import InvokeAIMetadata
from invokeai.app.modules.metadata import MetadataModule
from ..models.image import ImageField, ImageType
from .baseinvocation import (
@@ -151,13 +151,8 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id, invocation=self.dict()
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
@@ -214,13 +209,8 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id, invocation=self.dict()
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
@@ -256,9 +246,10 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id, invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
@@ -292,9 +283,10 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id, invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
@@ -328,9 +320,10 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id, invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
@@ -369,9 +362,10 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id, invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image

View File

@@ -5,9 +5,9 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.metadata import InvokeAIMetadata
from invokeai.app.modules.metadata import MetadataModule
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.model_management.model_manager import ModelManager
@@ -358,10 +358,10 @@ class LatentsToImageInvocation(BaseInvocation):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type,

View File

@@ -1,11 +1,10 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import InvokeAIMetadata
from ..services.invocation_services import InvocationServices
from invokeai.app.modules.metadata import MetadataModule
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
@@ -46,10 +45,10 @@ class RestoreFaceInvocation(BaseInvocation):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,

View File

@@ -1,13 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import InvokeAIMetadata
from ..services.invocation_services import InvocationServices
from invokeai.app.modules.metadata import MetadataModule
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
@@ -51,10 +49,10 @@ class UpscaleInvocation(BaseInvocation):
context.graph_execution_state_id, self.id
)
metadata = InvokeAIMetadata(
session_id=context.graph_execution_state_id,
invocation=self.dict()
metadata = MetadataModule.build_metadata(
session_id=context.graph_execution_state_id, invocation=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,

View File

@@ -1,23 +0,0 @@
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
class InvokeAIMetadata(BaseModel):
"""An image's InvokeAI-specific metadata"""
session_id: str = Field(description="The session that generated this image")
invocation: dict = Field(
default={}, description="The prepared invocation that generated this image"
)
class ImageMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
mode: str = Field(description="The color mode of the image")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)

View File

@@ -0,0 +1,180 @@
import json
from typing import Any, Dict, Literal, Optional
from PIL import Image, PngImagePlugin
from pydantic import (
BaseModel,
Extra,
Field,
StrictBool,
StrictInt,
StrictStr,
ValidationError,
root_validator,
)
from invokeai.app.models.image import ImageType
class MetadataImageField(BaseModel):
    """A non-nullable version of ImageField"""

    # Literal built at runtime from the ImageType enum values; pydantic accepts
    # this, but type checkers cannot express it, hence the ignore.
    image_type: Literal[tuple([t.value for t in ImageType])]  # type: ignore
    image_name: StrictStr
class MetadataLatentsField(BaseModel):
    """A non-nullable version of LatentsField"""

    # name under which the latents are stored
    latents_name: StrictStr
# Union of all valid metadata field types - use mostly strict types
NodeMetadataFieldTypes = (
    StrictStr | StrictInt | float | StrictBool  # we want to cast ints to floats here
)


class NodeMetadataField(BaseModel):
    """Helper class used as a hack for arbitrary metadata field keys."""

    # Custom root type: a single {key: scalar} mapping. Instances are never
    # kept; the class exists only so NodeMetadata can validate arbitrary
    # scalar attributes against NodeMetadataFieldTypes.
    __root__: Dict[StrictStr, NodeMetadataFieldTypes]
# `extra=Extra.allow` allows this to model any potential node with `id` and `type` fields
class NodeMetadata(BaseModel, extra=Extra.allow):
    """Node metadata model, used for validation of metadata."""

    @root_validator
    def validate_node_metadata(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Parses the node metadata, ignoring invalid values"""
        parsed: dict[str, Any] = {}

        # Conditionally build the parsed metadata, silently skipping invalid values
        for name, value in values.items():
            # explicitly parse `id` and `type` as strings
            if name == "id":
                if type(value) is not str:
                    continue
                parsed[name] = value
            elif name == "type":
                if type(value) is not str:
                    continue
                parsed[name] = value
            else:
                try:
                    if type(value) is dict:
                        # we only allow certain dicts, else just ignore the value entirely
                        if "image_name" in value or "image_type" in value:
                            # parse as an ImageField
                            parsed[name] = MetadataImageField.parse_obj(value)
                        elif "latents_name" in value:
                            # this is a LatentsField
                            parsed[name] = MetadataLatentsField.parse_obj(value)
                        # NOTE(review): unrecognized dict shapes fall through
                        # with no assignment, i.e. they are dropped silently
                    else:
                        # hack to get parse and validate arbitrary keys: a
                        # ValidationError here means the scalar is not one of
                        # NodeMetadataFieldTypes, so it is skipped below
                        NodeMetadataField.parse_obj({name: value})
                        parsed[name] = value
                except ValidationError:
                    # TODO: do we want to somehow alert when metadata is not fully valid?
                    continue
        return parsed
class InvokeAIMetadata(BaseModel):
    """An image's InvokeAI-specific metadata: the originating session and node."""

    session_id: Optional[StrictStr] = Field(
        description="The session in which this image was created"
    )
    node: Optional[NodeMetadata] = Field(description="The node that created this image")

    @root_validator(pre=True)
    def validate_invokeai_metadata(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Pre-validator: drops unknown keys and invalid values instead of raising."""
        parsed: dict[str, Any] = {}

        # Conditionally build the parsed metadata, silently skipping invalid values
        for name, value in values.items():
            if name == "session_id":
                if type(value) is not str:
                    continue
                parsed[name] = value
            elif name == "node":
                try:
                    p = NodeMetadata.parse_obj(value)
                    # check for empty NodeMetadata object (all attrs invalid/absent)
                    if len(p.dict().items()) == 0:
                        continue
                except ValidationError:
                    continue
                # store the raw value; pydantic re-parses it into NodeMetadata
                parsed[name] = value
        return parsed
class ImageMetadata(BaseModel):
    """An image's metadata. Used only in HTTP responses."""

    created: int = Field(description="The creation timestamp of the image")
    width: int = Field(description="The width of the image in pixels")
    height: int = Field(description="The height of the image in pixels")
    mode: str = Field(description="The color mode of the image")
    # present only when the image carries parseable InvokeAI metadata
    invokeai: Optional[InvokeAIMetadata] = Field(
        description="The image's InvokeAI-specific metadata"
    )
class MetadataModule:
    """Handles loading metadata from images and parsing it.

    Stateless namespace of static helpers, usable where an
    `InvocationContext` is not available.
    """

    # TODO: Support parsing old format metadata **hurk**

    @staticmethod
    def _load_metadata(image: Image.Image, key="invokeai") -> Any:
        """Loads a specific info entry from a PIL Image.

        Returns None when the entry is missing, is not a string, or is not
        valid JSON.
        """
        raw_metadata = image.info.get(key)

        # PNG text chunks are strings; anything else cannot be our metadata
        if type(raw_metadata) is not str:
            return None

        try:
            return json.loads(raw_metadata)
        except json.JSONDecodeError:
            # Corrupt/truncated JSON: treat as missing metadata rather than
            # raising, consistent with get_metadata's skip-invalid contract.
            return None

    @staticmethod
    def _parse_invokeai_metadata(
        metadata: Any,
    ) -> InvokeAIMetadata | None:
        """Parses an object as InvokeAI metadata, or None if it is not a dict."""
        if type(metadata) is not dict:
            return None

        # InvokeAIMetadata's pre-validator silently drops invalid values
        return InvokeAIMetadata.parse_obj(metadata)

    @staticmethod
    def get_metadata(image: Image.Image) -> InvokeAIMetadata | None:
        """Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
        loaded_metadata = MetadataModule._load_metadata(image)
        return MetadataModule._parse_invokeai_metadata(loaded_metadata)

    @staticmethod
    def build_metadata(
        session_id: StrictStr, invocation: BaseModel
    ) -> InvokeAIMetadata:
        """Builds an InvokeAIMetadata object from a session id and an invocation."""
        return InvokeAIMetadata(
            session_id=session_id, node=NodeMetadata(**invocation.dict())
        )

    @staticmethod
    def build_png_info(metadata: InvokeAIMetadata | None) -> PngImagePlugin.PngInfo:
        """Builds a PngInfo object, embedding metadata under "invokeai" if given."""
        png_info = PngImagePlugin.PngInfo()
        if metadata is not None:
            png_info.add_text("invokeai", metadata.json())
        return png_info

View File

@@ -14,9 +14,10 @@ import PIL.Image as PILImage
from PIL import PngImagePlugin
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata, InvokeAIMetadata
from invokeai.app.modules.metadata import ImageMetadata, InvokeAIMetadata, MetadataModule
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.get_timestamp import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from invokeai.backend.image_util import PngWriter
@@ -113,11 +114,7 @@ class DiskImageStorage(ImageStorageBase):
filename = os.path.basename(path)
img = PILImage.open(path)
# TODO: handle old `sd-metadata` format
invokeai_metadata = img.info.get("invokeai", None)
if invokeai_metadata is not None:
invokeai_metadata = InvokeAIMetadata(**json.loads(invokeai_metadata))
invokeai_metadata = MetadataModule.get_metadata(img)
page_of_images.append(
ImageResponse(
@@ -188,22 +185,18 @@ class DiskImageStorage(ImageStorageBase):
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: InvokeAIMetadata | None = None) -> Tuple[str, str, int]:
image_path = self.get_path(image_type, image_name)
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
info = PngImagePlugin.PngInfo()
png_info = MetadataModule.build_png_info(metadata=metadata)
if metadata:
info.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=png_info)
image.save(image_path, "PNG", pnginfo=info)
thumbnail = image.copy()
thumbnail.thumbnail(size=(256, 256))
thumbnail.save(thumbnail_path, "WEBP")
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail)
self.__set_cache(thumbnail_path, thumbnail_image)
return (image_path, thumbnail_path, int(os.path.getctime(image_path)))

View File

@@ -1,27 +0,0 @@
import os
from PIL import Image
from invokeai.app.models.image import ImageType
def save_thumbnail(
image: Image.Image,
filename: str,
image_type: ImageType,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path

View File

@@ -0,0 +1,15 @@
import os
from PIL import Image
def get_thumbnail_name(image_name: str) -> str:
    """Given an image name, returns the matching thumbnail name (same stem, `.webp`)."""
    stem, _extension = os.path.splitext(image_name)
    return f"{stem}.webp"
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
    """Returns a thumbnail copy of *image*, scaled to fit within a size x size box."""
    bounding_box = (size, size)
    # copy first: Image.thumbnail mutates in place
    thumbnail_image = image.copy()
    thumbnail_image.thumbnail(size=bounding_box)
    return thumbnail_image

View File

@@ -139,7 +139,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
});
const sessionId = image.metadata.invokeai?.session_id;
const invocation = image.metadata.invokeai?.invocation;
const node = image.metadata.invokeai?.node as Record<string, any>;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
@@ -170,105 +170,101 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{Object.keys(invocation).length > 0 ? (
{node && Object.keys(node).length > 0 ? (
<>
{invocation.type && (
<MetadataItem label="Invocation type" value={invocation.type} />
{node.type && (
<MetadataItem label="Invocation type" value={node.type} />
)}
{invocation.model && (
<MetadataItem label="Model" value={invocation.model} />
)}
{invocation.prompt && (
{node.model && <MetadataItem label="Model" value={node.model} />}
{node.prompt && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof invocation.prompt === 'string'
? invocation.prompt
: promptToString(invocation.prompt)
typeof node.prompt === 'string'
? node.prompt
: promptToString(node.prompt)
}
onClick={() => setBothPrompts(invocation.prompt)}
onClick={() => setBothPrompts(node.prompt)}
/>
)}
{invocation.seed !== undefined && (
{node.seed !== undefined && (
<MetadataItem
label="Seed"
value={invocation.seed}
onClick={() => dispatch(setSeed(invocation.seed))}
value={node.seed}
onClick={() => dispatch(setSeed(node.seed))}
/>
)}
{invocation.threshold !== undefined && (
{node.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={invocation.threshold}
onClick={() => dispatch(setThreshold(invocation.threshold))}
value={node.threshold}
onClick={() => dispatch(setThreshold(node.threshold))}
/>
)}
{invocation.perlin !== undefined && (
{node.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={invocation.perlin}
onClick={() => dispatch(setPerlin(invocation.perlin))}
value={node.perlin}
onClick={() => dispatch(setPerlin(node.perlin))}
/>
)}
{invocation.scheduler && (
{node.scheduler && (
<MetadataItem
label="Sampler"
value={invocation.scheduler}
onClick={() => dispatch(setSampler(invocation.scheduler))}
value={node.scheduler}
onClick={() => dispatch(setSampler(node.scheduler))}
/>
)}
{invocation.steps && (
{node.steps && (
<MetadataItem
label="Steps"
value={invocation.steps}
onClick={() => dispatch(setSteps(invocation.steps))}
value={node.steps}
onClick={() => dispatch(setSteps(node.steps))}
/>
)}
{invocation.cfg_scale !== undefined && (
{node.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={invocation.cfg_scale}
onClick={() => dispatch(setCfgScale(invocation.cfg_scale))}
value={node.cfg_scale}
onClick={() => dispatch(setCfgScale(node.cfg_scale))}
/>
)}
{invocation.variations && invocation.variations.length > 0 && (
{node.variations && node.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(invocation.variations)}
value={seedWeightsToString(node.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(invocation.variations))
)
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
}
/>
)}
{invocation.seamless && (
{node.seamless && (
<MetadataItem
label="Seamless"
value={invocation.seamless}
onClick={() => dispatch(setSeamless(invocation.seamless))}
value={node.seamless}
onClick={() => dispatch(setSeamless(node.seamless))}
/>
)}
{invocation.hires_fix && (
{node.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={invocation.hires_fix}
onClick={() => dispatch(setHiresFix(invocation.hires_fix))}
value={node.hires_fix}
onClick={() => dispatch(setHiresFix(node.hires_fix))}
/>
)}
{invocation.width && (
{node.width && (
<MetadataItem
label="Width"
value={invocation.width}
onClick={() => dispatch(setWidth(invocation.width))}
value={node.width}
onClick={() => dispatch(setWidth(node.width))}
/>
)}
{invocation.height && (
{node.height && (
<MetadataItem
label="Height"
value={invocation.height}
onClick={() => dispatch(setHeight(invocation.height))}
value={node.height}
onClick={() => dispatch(setHeight(node.height))}
/>
)}
{/* {init_image_path && (
@@ -279,20 +275,18 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{invocation.strength && (
{node.strength && (
<MetadataItem
label="Image to image strength"
value={invocation.strength}
onClick={() => dispatch(setImg2imgStrength(invocation.strength))}
value={node.strength}
onClick={() => dispatch(setImg2imgStrength(node.strength))}
/>
)}
{invocation.fit && (
{node.fit && (
<MetadataItem
label="Image to image fit"
value={invocation.fit}
onClick={() =>
dispatch(setShouldFitToWidthHeight(invocation.fit))
}
value={node.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
/>
)}
</>

View File

@@ -47,6 +47,7 @@ export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { ModelsList } from './models/ModelsList';
export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NodeMetadata } from './models/NodeMetadata';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
@@ -106,6 +107,7 @@ export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
export { $MaskOutput } from './schemas/$MaskOutput';
export { $ModelsList } from './schemas/$ModelsList';
export { $MultiplyInvocation } from './schemas/$MultiplyInvocation';
export { $NodeMetadata } from './schemas/$NodeMetadata';
export { $NoiseInvocation } from './schemas/$NoiseInvocation';
export { $NoiseOutput } from './schemas/$NoiseOutput';
export { $PaginatedResults_GraphExecutionState_ } from './schemas/$PaginatedResults_GraphExecutionState_';

View File

@@ -2,17 +2,16 @@
/* tslint:disable */
/* eslint-disable */
/**
* An image's InvokeAI-specific metadata
*/
import type { NodeMetadata } from './NodeMetadata';
export type InvokeAIMetadata = {
/**
* The session that generated this image
* The session in which this image was created
*/
session_id: string;
session_id?: string;
/**
* The prepared invocation that generated this image
* The node that created this image
*/
invocation?: any;
node?: NodeMetadata;
};

View File

@@ -0,0 +1,10 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Node metadata model, used for validation of metadata.
*/
export type NodeMetadata = {
};

View File

@@ -24,7 +24,7 @@ export type RandomRangeInvocation = {
*/
size?: number;
/**
* The seed for the RNG
* The seed for the RNG, provide None or -1 for random
*/
seed?: number;
};

View File

@@ -2,17 +2,17 @@
/* tslint:disable */
/* eslint-disable */
export const $InvokeAIMetadata = {
description: `An image's InvokeAI-specific metadata`,
properties: {
session_id: {
type: 'string',
description: `The session that generated this image`,
isRequired: true,
description: `The session in which this image was created`,
},
invocation: {
description: `The prepared invocation that generated this image`,
properties: {
},
node: {
type: 'all-of',
description: `The node that created this image`,
contains: [{
type: 'NodeMetadata',
}],
},
},
} as const;

View File

@@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $NodeMetadata = {
description: `Node metadata model, used for validation of metadata.`,
properties: {
},
} as const;

View File

@@ -26,7 +26,7 @@ export const $RandomRangeInvocation = {
},
seed: {
type: 'number',
description: `The seed for the RNG`,
description: `The seed for the RNG, provide None or -1 for random`,
maximum: 2147483647,
},
},

View File

@@ -0,0 +1,252 @@
import json
import os
from copy import deepcopy
from PIL import Image, PngImagePlugin
from invokeai.app.invocations.generate import TextToImageInvocation
from invokeai.app.modules.metadata import InvokeAIMetadata, MetadataModule
# A known-good metadata dict; every "bad" fixture below is a mutated deep copy.
good_metadata_dict = {
    "session_id": "1",
    "node": {
        "id": "1",
        "type": "txt2img",
        "prompt": "dog",
        "seed": 178785523,
        "steps": 30,
        "width": 512,
        "height": 512,
        "image": {"image_type": "results", "image_name": "1"},
        "cfg_scale": 7.5,
        "scheduler": "k_lms",
        "seamless": False,
        "model": "stable-diffusion-1.5",
        "progress_images": True,
    },
}


def _good_copy():
    """Return a fresh deep copy of the known-good metadata fixture."""
    return deepcopy(good_metadata_dict)


# Broken top-level `session_id`.
bad_metadata_dict_missing_session_id = _good_copy()
bad_metadata_dict_missing_session_id["session_id"] = None

bad_metadata_dict_invalid_session_id = _good_copy()
bad_metadata_dict_invalid_session_id["session_id"] = 123

# Broken top-level `node`.
bad_metadata_dict_missing_node = _good_copy()
bad_metadata_dict_missing_node["node"] = None

bad_metadata_dict_invalid_node = _good_copy()
bad_metadata_dict_invalid_node["node"] = 123

# Broken required node attributes (`id`, `type`).
bad_metadata_dict_missing_node_id = _good_copy()
del bad_metadata_dict_missing_node_id["node"]["id"]

bad_metadata_dict_invalid_node_id = _good_copy()
bad_metadata_dict_invalid_node_id["node"]["id"] = 123

bad_metadata_dict_missing_node_type = _good_copy()
del bad_metadata_dict_missing_node_type["node"]["type"]

bad_metadata_dict_invalid_node_type = _good_copy()
bad_metadata_dict_invalid_node_type["node"]["type"] = 123

bad_metadata_dict_no_node_attrs = _good_copy()
bad_metadata_dict_no_node_attrs["node"] = {}

# Node attribute values of unsupported container types.
bad_metadata_dict_array_attr = _good_copy()
bad_metadata_dict_array_attr["node"]["seed"] = [1, 2, 3]

bad_metadata_dict_invalid_dict_attr = _good_copy()
bad_metadata_dict_invalid_dict_attr["node"]["seed"] = {"a": 1}

# Broken `image` field on the node.
bad_metadata_dict_missing_image_field_image_type = _good_copy()
del bad_metadata_dict_missing_image_field_image_type["node"]["image"]["image_type"]

bad_metadata_dict_invalid_image_field_image_type = _good_copy()
bad_metadata_dict_invalid_image_field_image_type["node"]["image"]["image_type"] = "bad image type"

# Broken `latents` field on the node.
bad_metadata_dict_invalid_latents_field_latents_name = _good_copy()
bad_metadata_dict_invalid_latents_field_latents_name["node"]["latents"] = {"latents_name": 123}
def test_is_good_metadata_unchanged():
    """Valid metadata should round-trip through parsing without modification."""
    result = MetadataModule._parse_invokeai_metadata(good_metadata_dict)
    assert result.dict() == good_metadata_dict
def test_bad_metadata_dict_missing_session_id():
    """A null session_id survives parsing unchanged."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_missing_session_id)
    assert result.dict() == bad_metadata_dict_missing_session_id
def test_bad_metadata_dict_invalid_session_id():
    """An invalid session_id parses to the same result as a missing one."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_invalid_session_id)
    assert result.dict() == bad_metadata_dict_missing_session_id
def test_bad_metadata_dict_missing_node():
    """A null node survives parsing unchanged."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_missing_node)
    assert result.dict() == bad_metadata_dict_missing_node
def test_bad_metadata_dict_invalid_node():
    """An invalid node parses to the same result as a missing one."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_invalid_node)
    assert result.dict() == bad_metadata_dict_missing_node
def test_bad_metadata_dict_missing_node_id():
    """A node lacking its `id` attribute survives parsing unchanged."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_missing_node_id)
    assert result.dict() == bad_metadata_dict_missing_node_id
def test_bad_metadata_dict_invalid_node_id():
    """An invalid node `id` parses to the same result as a missing one."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_invalid_node_id)
    assert result.dict() == bad_metadata_dict_missing_node_id
def test_bad_metadata_dict_missing_node_type():
    """A node lacking its `type` attribute survives parsing unchanged."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_missing_node_type)
    assert result.dict() == bad_metadata_dict_missing_node_type
def test_bad_metadata_dict_invalid_node_type():
    """An invalid node `type` parses to the same result as a missing one."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_invalid_node_type)
    assert result.dict() == bad_metadata_dict_missing_node_type
def test_bad_metadata_dict_no_node_attrs():
    """An empty node dict parses to the same result as a missing node."""
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_no_node_attrs)
    assert result.dict() == bad_metadata_dict_missing_node
def test_bad_metadata_dict_array_attr():
    """An array-valued node attribute is dropped by parsing."""
    without_seed = deepcopy(good_metadata_dict)
    del without_seed["node"]["seed"]
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_array_attr)
    assert result.dict() == without_seed
def test_bad_metadata_dict_invalid_dict_attr():
    """An unexpected dict-valued node attribute is dropped by parsing."""
    without_seed = deepcopy(good_metadata_dict)
    del without_seed["node"]["seed"]
    result = MetadataModule._parse_invokeai_metadata(bad_metadata_dict_invalid_dict_attr)
    assert result.dict() == without_seed
def test_bad_metadata_dict_missing_image_field_image_type():
    """An image field missing `image_type` is dropped entirely by parsing."""
    without_image = deepcopy(good_metadata_dict)
    del without_image["node"]["image"]
    result = MetadataModule._parse_invokeai_metadata(
        bad_metadata_dict_missing_image_field_image_type
    )
    assert result.dict() == without_image
def test_bad_metadata_dict_invalid_image_field_image_type():
    """An image field with an unknown `image_type` is dropped entirely by parsing."""
    without_image = deepcopy(good_metadata_dict)
    del without_image["node"]["image"]
    result = MetadataModule._parse_invokeai_metadata(
        bad_metadata_dict_invalid_image_field_image_type
    )
    assert result.dict() == without_image
def test_bad_metadata_dict_invalid_latents_field_latents_name():
    """A latents field with an invalid `latents_name` is dropped, restoring the good dict."""
    result = MetadataModule._parse_invokeai_metadata(
        bad_metadata_dict_invalid_latents_field_latents_name
    )
    assert result.dict() == good_metadata_dict
def test_can_load_and_parse_invokeai_metadata(tmp_path):
    """Metadata written into a PNG `invokeai` text chunk can be loaded and parsed back."""
    raw_metadata = {"session_id": "123", "node": {"id": "456", "type": "test_type"}}
    png_path = os.path.join(tmp_path, "test.png")

    # Write the metadata into a fresh PNG's info chunk.
    info = PngImagePlugin.PngInfo()
    info.add_text("invokeai", json.dumps(raw_metadata))
    Image.new("RGB", (512, 512)).save(png_path, pnginfo=info)

    image = Image.open(png_path)
    loaded = MetadataModule._load_metadata(image)
    parsed = MetadataModule._parse_invokeai_metadata(loaded)
    combined = MetadataModule.get_metadata(image)  # load + parse in one call

    assert loaded == raw_metadata
    assert parsed.dict() == raw_metadata
    assert combined.dict() == raw_metadata
def test_can_build_invokeai_metadata():
    """build_metadata() wraps a session id and invocation into InvokeAIMetadata."""
    invocation = TextToImageInvocation(
        id="456",
        prompt="test",
        seed=1,
        steps=10,
        width=512,
        height=512,
        cfg_scale=7.5,
        scheduler="k_lms",
        seamless=False,
        model="test_mode",
        progress_images=True,
    )
    metadata = MetadataModule.build_metadata(session_id="123", invocation=invocation)

    expected = {
        "session_id": "123",
        "node": {
            "id": "456",
            "type": "txt2img",
            "prompt": "test",
            "seed": 1,
            "steps": 10,
            "width": 512,
            "height": 512,
            "cfg_scale": 7.5,
            "scheduler": "k_lms",
            "seamless": False,
            "model": "test_mode",
            "progress_images": True,
        },
    }
    # Exact type check (not isinstance) mirrors the module's contract.
    assert type(metadata) is InvokeAIMetadata
    assert metadata.dict() == expected