feat(nodes): add metadata module + tests, thumbnails
- `MetadataModule` is stateless and is needed in places where the `InvocationContext` is not available, so it has not been made a `service`
- Handles loading/parsing/building metadata, and creating PNG info objects
- Added tests for `MetadataModule`
- Lifted thumbnail stuff to util
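In rough terms, the new module is used like this (a sketch only, not part of the diff: it assumes `TextToImageInvocation` accepts defaults for the fields not given, and the file names are made up):

    from PIL import Image

    from invokeai.app.invocations.generate import TextToImageInvocation
    from invokeai.app.modules.metadata import MetadataModule

    # Build metadata from a session id and an invocation (any pydantic model);
    # the validators silently drop anything invalid.
    invocation = TextToImageInvocation(id="1", prompt="dog")  # defaults assumed OK
    metadata = MetadataModule.build_metadata(session_id="123", invocation=invocation)

    # Embed it in a PNG on save, then round-trip it back out.
    image = Image.new("RGB", (64, 64))  # stand-in for a generated image
    png_info = MetadataModule.build_png_info(metadata=metadata)
    image.save("out.png", "PNG", pnginfo=png_info)

    loaded = MetadataModule.get_metadata(Image.open("out.png"))
    assert loaded is not None and loaded.session_id == "123"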
@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field

 from invokeai.app.models.image import ImageType
-from invokeai.app.models.metadata import ImageMetadata
+from invokeai.app.modules.metadata import ImageMetadata


 class ImageResponse(BaseModel):
@@ -10,7 +10,7 @@ from fastapi.responses import FileResponse, Response
 from fastapi.routing import APIRouter
 from PIL import Image

 from invokeai.app.api.models.images import ImageResponse
-from invokeai.app.models.metadata import ImageMetadata, InvokeAIMetadata
+from invokeai.app.modules.metadata import ImageMetadata, InvokeAIMetadata
 from invokeai.app.services.item_storage import PaginatedResults

 from ...services.image_storage import ImageType
@@ -6,7 +6,12 @@ import numpy as np
 import numpy.random
 from pydantic import Field

-from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
+from .baseinvocation import (
+    BaseInvocation,
+    InvocationConfig,
+    InvocationContext,
+    BaseInvocationOutput,
+)


 class IntCollectionOutput(BaseInvocationOutput):
@@ -41,12 +46,14 @@ class RandomRangeInvocation(BaseInvocation):

     # Inputs
     low: int = Field(default=0, description="The inclusive low value")
-    high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
+    high: int = Field(
+        default=np.iinfo(np.int32).max, description="The exclusive high value"
+    )
     size: int = Field(default=1, description="The number of values to generate")
     seed: Optional[int] = Field(
         ge=0,
         le=np.iinfo(np.int32).max,
-        description="The seed for the RNG",
+        description="The seed for the RNG, provide None or -1 for random",
         default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
     )
@@ -8,7 +8,7 @@ from PIL import Image, ImageOps
 from pydantic import BaseModel, Field

 from invokeai.app.models.image import ImageField, ImageType
-from invokeai.app.models.metadata import InvokeAIMetadata
+from invokeai.app.modules.metadata import MetadataModule
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output
@@ -58,9 +58,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, image_inpainted, metadata)
@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field

 from invokeai.app.models.image import ImageField, ImageType
 from invokeai.app.invocations.util.choose_model import choose_model
-from invokeai.app.models.metadata import InvokeAIMetadata
+from invokeai.app.modules.metadata import MetadataModule
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output
 from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@@ -58,7 +58,10 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):

     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState,
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        intermediate_state: PipelineIntermediateState,
     ) -> None:
         stable_diffusion_step_callback(
             context=context,
@@ -72,7 +75,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         model = choose_model(context.services.model_manager, self.model)

         # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         outputs = Txt2Img(model).generate(
@@ -93,13 +98,14 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         image_name = context.services.images.create_name(
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

-        context.services.images.save(image_type, image_name, generate_output.image, metadata)
+        context.services.images.save(
+            image_type, image_name, generate_output.image, metadata
+        )
         return build_image_output(
             image_type=image_type,
             image_name=image_name,
@@ -123,8 +129,11 @@ class ImageToImageInvocation(TextToImageInvocation):
         )

     def dispatch_progress(
-        self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
-    ) -> None:
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        intermediate_state: PipelineIntermediateState,
+    ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
@@ -146,18 +155,20 @@ class ImageToImageInvocation(TextToImageInvocation):
         model = choose_model(context.services.model_manager, self.model)

         # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         outputs = Img2Img(model).generate(
-                prompt=self.prompt,
-                init_image=image,
-                init_mask=mask,
-                step_callback=partial(self.dispatch_progress, context, source_node_id),
-                **self.dict(
-                    exclude={"prompt", "image", "mask"}
-                ),  # Shorthand for passing all of the parameters above manually
-        )
+            prompt=self.prompt,
+            init_image=image,
+            init_mask=mask,
+            step_callback=partial(self.dispatch_progress, context, source_node_id),
+            **self.dict(
+                exclude={"prompt", "image", "mask"}
+            ),  # Shorthand for passing all of the parameters above manually
+        )

         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
         # each time it is called. We only need the first one.
@@ -172,10 +183,9 @@ class ImageToImageInvocation(TextToImageInvocation):
         image_name = context.services.images.create_name(
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, result_image, metadata)
@@ -185,6 +195,7 @@ class ImageToImageInvocation(TextToImageInvocation):
             image=result_image,
         )


 class InpaintInvocation(ImageToImageInvocation):
     """Generates an image using inpaint."""

@@ -200,7 +211,10 @@ class InpaintInvocation(ImageToImageInvocation):
         )

     def dispatch_progress(
-        self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        intermediate_state: PipelineIntermediateState,
     ) -> None:
         stable_diffusion_step_callback(
             context=context,
@@ -227,18 +241,20 @@ class InpaintInvocation(ImageToImageInvocation):
         model = choose_model(context.services.model_manager, self.model)

         # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        graph_execution_state = context.services.graph_execution_manager.get(
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]

         outputs = Inpaint(model).generate(
-                prompt=self.prompt,
-                init_img=image,
-                init_mask=mask,
-                step_callback=partial(self.dispatch_progress, context, source_node_id),
-                **self.dict(
-                    exclude={"prompt", "image", "mask"}
-                ),  # Shorthand for passing all of the parameters above manually
-        )
+            prompt=self.prompt,
+            init_img=image,
+            init_mask=mask,
+            step_callback=partial(self.dispatch_progress, context, source_node_id),
+            **self.dict(
+                exclude={"prompt", "image", "mask"}
+            ),  # Shorthand for passing all of the parameters above manually
+        )

         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
         # each time it is called. We only need the first one.
@@ -254,9 +270,8 @@ class InpaintInvocation(ImageToImageInvocation):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, result_image, metadata)
@@ -6,7 +6,7 @@ import numpy
 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field

-from invokeai.app.models.metadata import InvokeAIMetadata
+from invokeai.app.modules.metadata import MetadataModule

 from ..models.image import ImageField, ImageType
 from .baseinvocation import (
@@ -151,13 +151,8 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id, invocation=self.dict()
-        )
-
-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, image_crop, metadata)
@@ -214,13 +209,8 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id, invocation=self.dict()
-        )
-
-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, new_image, metadata)
@@ -256,9 +246,10 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id, invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, image_mask, metadata)
+
         return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
@@ -292,9 +283,10 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id, invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, blur_image, metadata)
+
         return build_image_output(
             image_type=image_type, image_name=image_name, image=blur_image
@@ -328,9 +320,10 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id, invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, lerp_image, metadata)
+
         return build_image_output(
             image_type=image_type, image_name=image_name, image=lerp_image
@@ -369,9 +362,10 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id, invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, ilerp_image, metadata)
+
         return build_image_output(
             image_type=image_type, image_name=image_name, image=ilerp_image
@@ -5,9 +5,9 @@ from typing import Literal, Optional
 from pydantic import BaseModel, Field
 import torch

 from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.invocations.util.choose_model import choose_model
-from invokeai.app.models.metadata import InvokeAIMetadata
+from invokeai.app.modules.metadata import MetadataModule

 from invokeai.app.util.step_callback import stable_diffusion_step_callback

 from ...backend.model_management.model_manager import ModelManager
@@ -358,10 +358,10 @@ class LatentsToImageInvocation(BaseInvocation):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, image, metadata)
         return build_image_output(
             image_type=image_type,
@@ -1,11 +1,10 @@
 from datetime import datetime, timezone
 from typing import Literal, Union

 from pydantic import Field

 from invokeai.app.models.image import ImageField, ImageType
-from invokeai.app.models.metadata import InvokeAIMetadata
-from ..services.invocation_services import InvocationServices
+from invokeai.app.modules.metadata import MetadataModule

 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output
@@ -46,10 +45,10 @@ class RestoreFaceInvocation(BaseInvocation):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, results[0][0], metadata)
         return build_image_output(
             image_type=image_type,
@@ -1,13 +1,11 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

 from datetime import datetime, timezone
 from typing import Literal, Union

 from pydantic import Field

 from invokeai.app.models.image import ImageField, ImageType
-from invokeai.app.models.metadata import InvokeAIMetadata
-from ..services.invocation_services import InvocationServices
+from invokeai.app.modules.metadata import MetadataModule
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output
@@ -51,10 +49,10 @@ class UpscaleInvocation(BaseInvocation):
             context.graph_execution_state_id, self.id
         )

-        metadata = InvokeAIMetadata(
-            session_id=context.graph_execution_state_id,
-            invocation=self.dict()
+        metadata = MetadataModule.build_metadata(
+            session_id=context.graph_execution_state_id, invocation=self
         )

         context.services.images.save(image_type, image_name, results[0][0], metadata)
         return build_image_output(
             image_type=image_type,
@@ -1,23 +0,0 @@
-from typing import Any, Optional, Dict
-from pydantic import BaseModel, Field
-
-
-class InvokeAIMetadata(BaseModel):
-    """An image's InvokeAI-specific metadata"""
-
-    session_id: str = Field(description="The session that generated this image")
-    invocation: dict = Field(
-        default={}, description="The prepared invocation that generated this image"
-    )
-
-
-class ImageMetadata(BaseModel):
-    """An image's metadata. Used only in HTTP responses."""
-
-    created: int = Field(description="The creation timestamp of the image")
-    width: int = Field(description="The width of the image in pixels")
-    height: int = Field(description="The height of the image in pixels")
-    mode: str = Field(description="The color mode of the image")
-    invokeai: Optional[InvokeAIMetadata] = Field(
-        description="The image's InvokeAI-specific metadata"
-    )
invokeai/app/modules/metadata.py (new file, +180)
@@ -0,0 +1,180 @@
+import json
+from typing import Any, Dict, Literal, Optional
+from PIL import Image, PngImagePlugin
+from pydantic import (
+    BaseModel,
+    Extra,
+    Field,
+    StrictBool,
+    StrictInt,
+    StrictStr,
+    ValidationError,
+    root_validator,
+)
+
+from invokeai.app.models.image import ImageType
+
+
+class MetadataImageField(BaseModel):
+    """A non-nullable version of ImageField"""
+
+    image_type: Literal[tuple([t.value for t in ImageType])]  # type: ignore
+    image_name: StrictStr
+
+
+class MetadataLatentsField(BaseModel):
+    """A non-nullable version of LatentsField"""
+
+    latents_name: StrictStr
+
+
+# Union of all valid metadata field types - use mostly strict types
+NodeMetadataFieldTypes = (
+    StrictStr | StrictInt | float | StrictBool  # we want to cast ints to floats here
+)
+
+
+class NodeMetadataField(BaseModel):
+    """Helper class used as a hack for arbitrary metadata field keys."""
+
+    __root__: Dict[StrictStr, NodeMetadataFieldTypes]
+
+
+# `extra=Extra.allow` allows this to model any potential node with `id` and `type` fields
+class NodeMetadata(BaseModel, extra=Extra.allow):
+    """Node metadata model, used for validation of metadata."""
+
+    @root_validator
+    def validate_node_metadata(cls, values: dict[str, Any]) -> dict[str, Any]:
+        """Parses the node metadata, ignoring invalid values"""
+        parsed: dict[str, Any] = {}
+
+        # Conditionally build the parsed metadata, silently skipping invalid values
+        for name, value in values.items():
+            # explicitly parse `id` and `type` as strings
+            if name == "id":
+                if type(value) is not str:
+                    continue
+                parsed[name] = value
+            elif name == "type":
+                if type(value) is not str:
+                    continue
+                parsed[name] = value
+            else:
+                try:
+                    if type(value) is dict:
+                        # we only allow certain dicts, else just ignore the value entirely
+                        if "image_name" in value or "image_type" in value:
+                            # parse as an ImageField
+                            parsed[name] = MetadataImageField.parse_obj(value)
+                        elif "latents_name" in value:
+                            # this is a LatentsField
+                            parsed[name] = MetadataLatentsField.parse_obj(value)
+                    else:
+                        # hack to get parse and validate arbitrary keys
+                        NodeMetadataField.parse_obj({name: value})
+                        parsed[name] = value
+                except ValidationError:
+                    # TODO: do we want to somehow alert when metadata is not fully valid?
+                    continue
+        return parsed
+
+
+class InvokeAIMetadata(BaseModel):
+    session_id: Optional[StrictStr] = Field(
+        description="The session in which this image was created"
+    )
+    node: Optional[NodeMetadata] = Field(description="The node that created this image")
+
+    @root_validator(pre=True)
+    def validate_invokeai_metadata(cls, values: dict[str, Any]) -> dict[str, Any]:
+        parsed: dict[str, Any] = {}
+        # Conditionally build the parsed metadata, silently skipping invalid values
+        for name, value in values.items():
+            if name == "session_id":
+                if type(value) is not str:
+                    continue
+                parsed[name] = value
+            elif name == "node":
+                try:
+                    p = NodeMetadata.parse_obj(value)
+                    # check for empty NodeMetadata object
+                    if len(p.dict().items()) == 0:
+                        continue
+                except ValidationError:
+                    continue
+                parsed[name] = value
+
+        return parsed
+
+
+class ImageMetadata(BaseModel):
+    """An image's metadata. Used only in HTTP responses."""
+
+    created: int = Field(description="The creation timestamp of the image")
+    width: int = Field(description="The width of the image in pixels")
+    height: int = Field(description="The height of the image in pixels")
+    mode: str = Field(description="The color mode of the image")
+    invokeai: Optional[InvokeAIMetadata] = Field(
+        description="The image's InvokeAI-specific metadata"
+    )
+
+
+class MetadataModule:
+    """Handles loading metadata from images and parsing it."""
+
+    # TODO: Support parsing old format metadata **hurk**
+
+    @staticmethod
+    def _load_metadata(image: Image.Image, key="invokeai") -> Any:
+        """Loads a specific info entry from a PIL Image."""
+
+        raw_metadata = image.info.get(key)
+
+        # metadata should always be a dict
+        if type(raw_metadata) is not str:
+            return None
+
+        loaded_metadata = json.loads(raw_metadata)
+
+        return loaded_metadata
+
+    @staticmethod
+    def _parse_invokeai_metadata(
+        metadata: Any,
+    ) -> InvokeAIMetadata | None:
+        """Parses an object as InvokeAI metadata."""
+        if type(metadata) is not dict:
+            return None
+
+        parsed_metadata = InvokeAIMetadata.parse_obj(metadata)
+
+        return parsed_metadata
+
+    @staticmethod
+    def get_metadata(image: Image.Image) -> InvokeAIMetadata | None:
+        """Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
+        loaded_metadata = MetadataModule._load_metadata(image)
+        parsed_metadata = MetadataModule._parse_invokeai_metadata(loaded_metadata)
+
+        return parsed_metadata
+
+    @staticmethod
+    def build_metadata(
+        session_id: StrictStr, invocation: BaseModel
+    ) -> InvokeAIMetadata:
+        """Builds an InvokeAIMetadata object"""
+        metadata = InvokeAIMetadata(
+            session_id=session_id, node=NodeMetadata(**invocation.dict())
+        )
+
+        return metadata
+
+    @staticmethod
+    def build_png_info(metadata: InvokeAIMetadata | None):
+        png_info = PngImagePlugin.PngInfo()
+
+        if metadata is not None:
+            png_info.add_text("invokeai", metadata.json())
+
+        return png_info
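A sketch of the validators' skip-invalid behavior above (mirroring the tests added at the end of this commit; assumes pydantic v1 `parse_obj` semantics, as used throughout the file):

    from invokeai.app.modules.metadata import InvokeAIMetadata

    # `seed` is an array, which NodeMetadataField rejects, so it is dropped;
    # the valid `id` and `type` fields survive the parse untouched.
    m = InvokeAIMetadata.parse_obj(
        {"session_id": "1", "node": {"id": "1", "type": "txt2img", "seed": [1, 2, 3]}}
    )
    assert m.dict() == {"session_id": "1", "node": {"id": "1", "type": "txt2img"}}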
@@ -14,9 +14,10 @@ import PIL.Image as PILImage
 from PIL import PngImagePlugin
 from invokeai.app.api.models.images import ImageResponse
 from invokeai.app.models.image import ImageType
-from invokeai.app.models.metadata import ImageMetadata, InvokeAIMetadata
+from invokeai.app.modules.metadata import ImageMetadata, InvokeAIMetadata, MetadataModule
 from invokeai.app.services.item_storage import PaginatedResults
 from invokeai.app.util.get_timestamp import get_timestamp
+from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail

 from invokeai.backend.image_util import PngWriter
@@ -113,11 +114,7 @@ class DiskImageStorage(ImageStorageBase):
         filename = os.path.basename(path)
         img = PILImage.open(path)

-        # TODO: handle old `sd-metadata` format
-        invokeai_metadata = img.info.get("invokeai", None)
-
-        if invokeai_metadata is not None:
-            invokeai_metadata = InvokeAIMetadata(**json.loads(invokeai_metadata))
+        invokeai_metadata = MetadataModule.get_metadata(img)

         page_of_images.append(
             ImageResponse(
@@ -188,22 +185,18 @@ class DiskImageStorage(ImageStorageBase):
     def save(self, image_type: ImageType, image_name: str, image: Image, metadata: InvokeAIMetadata | None = None) -> Tuple[str, str, int]:
         image_path = self.get_path(image_type, image_name)

-        thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
+        thumbnail_name = get_thumbnail_name(image_name)
         thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)

-        info = PngImagePlugin.PngInfo()
+        png_info = MetadataModule.build_png_info(metadata=metadata)

-        if metadata:
-            info.add_text("invokeai", metadata.json())
-
-        image.save(image_path, "PNG", pnginfo=info)
+        image.save(image_path, "PNG", pnginfo=png_info)

-        thumbnail = image.copy()
-        thumbnail.thumbnail(size=(256, 256))
-        thumbnail.save(thumbnail_path, "WEBP")
+        thumbnail_image = make_thumbnail(image)
+        thumbnail_image.save(thumbnail_path)

         self.__set_cache(image_path, image)
-        self.__set_cache(thumbnail_path, thumbnail)
+        self.__set_cache(thumbnail_path, thumbnail_image)

         return (image_path, thumbnail_path, int(os.path.getctime(image_path)))
@@ -1,27 +0,0 @@
-import os
-from PIL import Image
-
-from invokeai.app.models.image import ImageType
-
-
-def save_thumbnail(
-    image: Image.Image,
-    filename: str,
-    image_type: ImageType,
-    size: int = 256,
-) -> str:
-    """
-    Saves a thumbnail of an image, returning its path.
-    """
-    base_filename = os.path.splitext(filename)[0]
-    thumbnail_path = os.path.join(path, base_filename + ".webp")
-
-    if os.path.exists(thumbnail_path):
-        return thumbnail_path
-
-    image_copy = image.copy()
-    image_copy.thumbnail(size=(size, size))
-
-    image_copy.save(thumbnail_path, "WEBP")
-
-    return thumbnail_path
invokeai/app/util/thumbnails.py (new file, +15)
@@ -0,0 +1,15 @@
+import os
+from PIL import Image
+
+
+def get_thumbnail_name(image_name: str) -> str:
+    """Given an image name, returns the appropriate thumbnail image name"""
+    thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
+    return thumbnail_name
+
+
+def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
+    """Makes a thumbnail from a PIL Image"""
+    thumbnail = image.copy()
+    thumbnail.thumbnail(size=(size, size))
+    return thumbnail
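A small usage sketch of these two helpers (the file name here is hypothetical; `Image.save` infers the WEBP format from the `.webp` extension):

    from PIL import Image

    from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail

    image = Image.new("RGB", (512, 384))  # stand-in for a saved image
    thumbnail = make_thumbnail(image, size=256)  # longest side capped at 256, aspect ratio kept
    thumbnail.save(get_thumbnail_name("some_image.png"))  # writes "some_image.webp"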
@@ -139,7 +139,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
   });

   const sessionId = image.metadata.invokeai?.session_id;
-  const invocation = image.metadata.invokeai?.invocation;
+  const node = image.metadata.invokeai?.node as Record<string, any>;

   const { t } = useTranslation();
   const { getUrl } = useGetUrl();
@@ -170,105 +170,101 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
            <ExternalLinkIcon mx="2px" />
          </Link>
        </Flex>
-        {Object.keys(invocation).length > 0 ? (
+        {node && Object.keys(node).length > 0 ? (
          <>
-            {invocation.type && (
-              <MetadataItem label="Invocation type" value={invocation.type} />
+            {node.type && (
+              <MetadataItem label="Invocation type" value={node.type} />
            )}
-            {invocation.model && (
-              <MetadataItem label="Model" value={invocation.model} />
-            )}
-            {invocation.prompt && (
+            {node.model && <MetadataItem label="Model" value={node.model} />}
+            {node.prompt && (
              <MetadataItem
                label="Prompt"
                labelPosition="top"
                value={
-                  typeof invocation.prompt === 'string'
-                    ? invocation.prompt
-                    : promptToString(invocation.prompt)
+                  typeof node.prompt === 'string'
+                    ? node.prompt
+                    : promptToString(node.prompt)
                }
-                onClick={() => setBothPrompts(invocation.prompt)}
+                onClick={() => setBothPrompts(node.prompt)}
              />
            )}
-            {invocation.seed !== undefined && (
+            {node.seed !== undefined && (
              <MetadataItem
                label="Seed"
-                value={invocation.seed}
-                onClick={() => dispatch(setSeed(invocation.seed))}
+                value={node.seed}
+                onClick={() => dispatch(setSeed(node.seed))}
              />
            )}
-            {invocation.threshold !== undefined && (
+            {node.threshold !== undefined && (
              <MetadataItem
                label="Noise Threshold"
-                value={invocation.threshold}
-                onClick={() => dispatch(setThreshold(invocation.threshold))}
+                value={node.threshold}
+                onClick={() => dispatch(setThreshold(node.threshold))}
              />
            )}
-            {invocation.perlin !== undefined && (
+            {node.perlin !== undefined && (
              <MetadataItem
                label="Perlin Noise"
-                value={invocation.perlin}
-                onClick={() => dispatch(setPerlin(invocation.perlin))}
+                value={node.perlin}
+                onClick={() => dispatch(setPerlin(node.perlin))}
              />
            )}
-            {invocation.scheduler && (
+            {node.scheduler && (
              <MetadataItem
                label="Sampler"
-                value={invocation.scheduler}
-                onClick={() => dispatch(setSampler(invocation.scheduler))}
+                value={node.scheduler}
+                onClick={() => dispatch(setSampler(node.scheduler))}
              />
            )}
-            {invocation.steps && (
+            {node.steps && (
              <MetadataItem
                label="Steps"
-                value={invocation.steps}
-                onClick={() => dispatch(setSteps(invocation.steps))}
+                value={node.steps}
+                onClick={() => dispatch(setSteps(node.steps))}
              />
            )}
-            {invocation.cfg_scale !== undefined && (
+            {node.cfg_scale !== undefined && (
              <MetadataItem
                label="CFG scale"
-                value={invocation.cfg_scale}
-                onClick={() => dispatch(setCfgScale(invocation.cfg_scale))}
+                value={node.cfg_scale}
+                onClick={() => dispatch(setCfgScale(node.cfg_scale))}
              />
            )}
-            {invocation.variations && invocation.variations.length > 0 && (
+            {node.variations && node.variations.length > 0 && (
              <MetadataItem
                label="Seed-weight pairs"
-                value={seedWeightsToString(invocation.variations)}
+                value={seedWeightsToString(node.variations)}
                onClick={() =>
-                  dispatch(
-                    setSeedWeights(seedWeightsToString(invocation.variations))
-                  )
+                  dispatch(setSeedWeights(seedWeightsToString(node.variations)))
                }
              />
            )}
-            {invocation.seamless && (
+            {node.seamless && (
              <MetadataItem
                label="Seamless"
-                value={invocation.seamless}
-                onClick={() => dispatch(setSeamless(invocation.seamless))}
+                value={node.seamless}
+                onClick={() => dispatch(setSeamless(node.seamless))}
              />
            )}
-            {invocation.hires_fix && (
+            {node.hires_fix && (
              <MetadataItem
                label="High Resolution Optimization"
-                value={invocation.hires_fix}
-                onClick={() => dispatch(setHiresFix(invocation.hires_fix))}
+                value={node.hires_fix}
+                onClick={() => dispatch(setHiresFix(node.hires_fix))}
              />
            )}
-            {invocation.width && (
+            {node.width && (
              <MetadataItem
                label="Width"
-                value={invocation.width}
-                onClick={() => dispatch(setWidth(invocation.width))}
+                value={node.width}
+                onClick={() => dispatch(setWidth(node.width))}
              />
            )}
-            {invocation.height && (
+            {node.height && (
              <MetadataItem
                label="Height"
-                value={invocation.height}
-                onClick={() => dispatch(setHeight(invocation.height))}
+                value={node.height}
+                onClick={() => dispatch(setHeight(node.height))}
              />
            )}
            {/* {init_image_path && (
@@ -279,20 +275,18 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
                onClick={() => dispatch(setInitialImage(init_image_path))}
              />
            )} */}
-            {invocation.strength && (
+            {node.strength && (
              <MetadataItem
                label="Image to image strength"
-                value={invocation.strength}
-                onClick={() => dispatch(setImg2imgStrength(invocation.strength))}
+                value={node.strength}
+                onClick={() => dispatch(setImg2imgStrength(node.strength))}
              />
            )}
-            {invocation.fit && (
+            {node.fit && (
              <MetadataItem
                label="Image to image fit"
-                value={invocation.fit}
-                onClick={() =>
-                  dispatch(setShouldFitToWidthHeight(invocation.fit))
-                }
+                value={node.fit}
+                onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
              />
            )}
          </>
@@ -47,6 +47,7 @@ export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
 export type { MaskOutput } from './models/MaskOutput';
 export type { ModelsList } from './models/ModelsList';
 export type { MultiplyInvocation } from './models/MultiplyInvocation';
+export type { NodeMetadata } from './models/NodeMetadata';
 export type { NoiseInvocation } from './models/NoiseInvocation';
 export type { NoiseOutput } from './models/NoiseOutput';
 export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
@@ -106,6 +107,7 @@ export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
 export { $MaskOutput } from './schemas/$MaskOutput';
 export { $ModelsList } from './schemas/$ModelsList';
 export { $MultiplyInvocation } from './schemas/$MultiplyInvocation';
+export { $NodeMetadata } from './schemas/$NodeMetadata';
 export { $NoiseInvocation } from './schemas/$NoiseInvocation';
 export { $NoiseOutput } from './schemas/$NoiseOutput';
 export { $PaginatedResults_GraphExecutionState_ } from './schemas/$PaginatedResults_GraphExecutionState_';
@@ -2,17 +2,16 @@
 /* tslint:disable */
 /* eslint-disable */

-/**
- * An image's InvokeAI-specific metadata
- */
+import type { NodeMetadata } from './NodeMetadata';

 export type InvokeAIMetadata = {
     /**
-     * The session that generated this image
+     * The session in which this image was created
      */
-    session_id: string;
+    session_id?: string;
     /**
-     * The prepared invocation that generated this image
+     * The node that created this image
      */
-    invocation?: any;
+    node?: NodeMetadata;
 };
@@ -0,0 +1,10 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * Node metadata model, used for validation of metadata.
+ */
+export type NodeMetadata = {
+};
+
@@ -24,7 +24,7 @@ export type RandomRangeInvocation = {
      */
     size?: number;
     /**
-     * The seed for the RNG
+     * The seed for the RNG, provide None or -1 for random
      */
     seed?: number;
 };
@@ -2,17 +2,17 @@
 /* tslint:disable */
 /* eslint-disable */
 export const $InvokeAIMetadata = {
-    description: `An image's InvokeAI-specific metadata`,
     properties: {
         session_id: {
             type: 'string',
-            description: `The session that generated this image`,
-            isRequired: true,
+            description: `The session in which this image was created`,
         },
-        invocation: {
-            description: `The prepared invocation that generated this image`,
-            properties: {
-            },
+        node: {
+            type: 'all-of',
+            description: `The node that created this image`,
+            contains: [{
+                type: 'NodeMetadata',
+            }],
         },
     },
 } as const;
@@ -0,0 +1,8 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $NodeMetadata = {
+    description: `Node metadata model, used for validation of metadata.`,
+    properties: {
+    },
+} as const;
@@ -26,7 +26,7 @@ export const $RandomRangeInvocation = {
        },
        seed: {
            type: 'number',
-            description: `The seed for the RNG`,
+            description: `The seed for the RNG, provide None or -1 for random`,
            maximum: 2147483647,
        },
    },
tests/nodes/test_metadata.py (new file, +252)
@@ -0,0 +1,252 @@
+import json
+import os
+
+from copy import deepcopy
+from PIL import Image, PngImagePlugin
+
+from invokeai.app.invocations.generate import TextToImageInvocation
+from invokeai.app.modules.metadata import InvokeAIMetadata, MetadataModule
+
+good_metadata_dict = {
+    "session_id": "1",
+    "node": {
+        "id": "1",
+        "type": "txt2img",
+        "prompt": "dog",
+        "seed": 178785523,
+        "steps": 30,
+        "width": 512,
+        "height": 512,
+        "image": {"image_type": "results", "image_name": "1"},
+        "cfg_scale": 7.5,
+        "scheduler": "k_lms",
+        "seamless": False,
+        "model": "stable-diffusion-1.5",
+        "progress_images": True,
+    },
+}
+
+bad_metadata_dict_missing_session_id = deepcopy(good_metadata_dict)
+bad_metadata_dict_missing_session_id["session_id"] = None
+
+bad_metadata_dict_invalid_session_id = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_session_id["session_id"] = 123
+
+bad_metadata_dict_missing_node = deepcopy(good_metadata_dict)
+bad_metadata_dict_missing_node["node"] = None
+
+bad_metadata_dict_invalid_node = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_node["node"] = 123
+
+bad_metadata_dict_missing_node_id = deepcopy(good_metadata_dict)
+del bad_metadata_dict_missing_node_id["node"]["id"]
+
+bad_metadata_dict_invalid_node_id = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_node_id["node"]["id"] = 123
+
+bad_metadata_dict_missing_node_type = deepcopy(good_metadata_dict)
+del bad_metadata_dict_missing_node_type["node"]["type"]
+
+bad_metadata_dict_invalid_node_type = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_node_type["node"]["type"] = 123
+
+bad_metadata_dict_no_node_attrs = deepcopy(good_metadata_dict)
+bad_metadata_dict_no_node_attrs["node"] = {}
+
+bad_metadata_dict_array_attr = deepcopy(good_metadata_dict)
+bad_metadata_dict_array_attr["node"]["seed"] = [1, 2, 3]
+
+bad_metadata_dict_invalid_dict_attr = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_dict_attr["node"]["seed"] = {"a": 1}
+
+bad_metadata_dict_missing_image_field_image_type = deepcopy(good_metadata_dict)
+del bad_metadata_dict_missing_image_field_image_type["node"]["image"]["image_type"]
+
+bad_metadata_dict_invalid_image_field_image_type = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_image_field_image_type["node"]["image"][
+    "image_type"
+] = "bad image type"
+
+bad_metadata_dict_invalid_latents_field_latents_name = deepcopy(good_metadata_dict)
+bad_metadata_dict_invalid_latents_field_latents_name["node"]["latents"] = {
+    "latents_name": 123
+}
+
+
+def test_is_good_metadata_unchanged():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(good_metadata_dict)
+    assert good_metadata_dict == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_missing_session_id():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_missing_session_id
+    )
+    assert bad_metadata_dict_missing_session_id == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_session_id():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_session_id
+    )
+    assert bad_metadata_dict_missing_session_id == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_missing_node():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_missing_node
+    )
+    assert bad_metadata_dict_missing_node == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_node():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_node
+    )
+    assert bad_metadata_dict_missing_node == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_missing_node_id():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_missing_node_id
+    )
+    assert bad_metadata_dict_missing_node_id == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_node_id():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_node_id
+    )
+    assert bad_metadata_dict_missing_node_id == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_missing_node_type():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_missing_node_type
+    )
+    assert bad_metadata_dict_missing_node_type == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_node_type():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_node_type
+    )
+    assert bad_metadata_dict_missing_node_type == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_no_node_attrs():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_no_node_attrs
+    )
+    assert bad_metadata_dict_missing_node == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_array_attr():
+    expected = deepcopy(good_metadata_dict)
+    del expected["node"]["seed"]
+
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_array_attr
+    )
+    assert expected == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_dict_attr():
+    expected = deepcopy(good_metadata_dict)
+    del expected["node"]["seed"]
+
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_dict_attr
+    )
+    assert expected == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_missing_image_field_image_type():
+    expected = deepcopy(good_metadata_dict)
+    del expected["node"]["image"]
+
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_missing_image_field_image_type
+    )
+    assert expected == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_image_field_image_type():
+    expected = deepcopy(good_metadata_dict)
+    del expected["node"]["image"]
+
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_image_field_image_type
+    )
+    assert expected == parsed_metadata.dict()
+
+
+def test_bad_metadata_dict_invalid_latents_field_latents_name():
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(
+        bad_metadata_dict_invalid_latents_field_latents_name
+    )
+    assert good_metadata_dict == parsed_metadata.dict()
+
+
+def test_can_load_and_parse_invokeai_metadata(tmp_path):
+    raw_metadata = {"session_id": "123", "node": {"id": "456", "type": "test_type"}}
+
+    temp_image = Image.new("RGB", (512, 512))
+    temp_image_path = os.path.join(tmp_path, "test.png")
+
+    pnginfo = PngImagePlugin.PngInfo()
+    pnginfo.add_text("invokeai", json.dumps(raw_metadata))
+
+    temp_image.save(temp_image_path, pnginfo=pnginfo)
+
+    image = Image.open(temp_image_path)
+
+    loaded_metadata = MetadataModule._load_metadata(image)
+    parsed_metadata = MetadataModule._parse_invokeai_metadata(loaded_metadata)
+    loaded_and_parsed_metadata = MetadataModule.get_metadata(image)
+
+    assert raw_metadata == loaded_metadata
+    assert raw_metadata == parsed_metadata.dict()
+    assert raw_metadata == loaded_and_parsed_metadata.dict()
+
+
+def test_can_build_invokeai_metadata():
+    session_id = "123"
+    invocation = TextToImageInvocation(
+        id="456",
+        prompt="test",
+        seed=1,
+        steps=10,
+        width=512,
+        height=512,
+        cfg_scale=7.5,
+        scheduler="k_lms",
+        seamless=False,
+        model="test_mode",
+        progress_images=True,
+    )
+
+    metadata = MetadataModule.build_metadata(
+        session_id=session_id, invocation=invocation
+    )
+
+    expected_metadata_dict = {
+        "session_id": "123",
+        "node": {
+            "id": "456",
+            "type": "txt2img",
+            "prompt": "test",
+            "seed": 1,
+            "steps": 10,
+            "width": 512,
+            "height": 512,
+            "cfg_scale": 7.5,
+            "scheduler": "k_lms",
+            "seamless": False,
+            "model": "test_mode",
+            "progress_images": True,
+        },
+    }
+
+    assert type(metadata) is InvokeAIMetadata
+    assert expected_metadata_dict == metadata.dict()