feat(nodes): move metadata parsing to frontend

This commit is contained in:
psychedelicious
2023-04-21 17:46:41 +10:00
parent 688d3a9453
commit be0a033b90
39 changed files with 639 additions and 518 deletions

View File

@@ -11,7 +11,6 @@ class ImageResponseMetadata(BaseModel):
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
mode: str = Field(description="The color mode of the image")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)

View File

@@ -3,6 +3,7 @@ import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import HTTPException, Path, Query, Request, UploadFile
@@ -74,6 +75,7 @@ async def upload_image(
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
except:
@@ -81,17 +83,13 @@ async def upload_image(
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
(image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
)
# TODO: handle old `sd-metadata` style metadata
invokeai_metadata = img.info.get("invokeai", None)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
if invokeai_metadata is not None:
invokeai_metadata = InvokeAIMetadata(**json.loads(invokeai_metadata))
# TODO: should creation of this object should happen elsewhere?
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
@@ -101,7 +99,6 @@ async def upload_image(
created=ctime,
width=img.width,
height=img.height,
mode=img.mode,
invokeai=invokeai_metadata,
),
)

View File

@@ -65,7 +65,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
invocation_dict=self.dict(),
node=self.dict(),
source_node_id=source_node_id,
)
@@ -136,7 +136,7 @@ class ImageToImageInvocation(TextToImageInvocation):
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
invocation_dict=self.dict(),
node=self.dict(),
source_node_id=source_node_id,
)
@@ -218,7 +218,7 @@ class InpaintInvocation(ImageToImageInvocation):
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
invocation_dict=self.dict(),
node=self.dict(),
source_node_id=source_node_id,
)

View File

@@ -177,7 +177,7 @@ class TextToLatentsInvocation(BaseInvocation):
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
invocation_dict=self.dict(),
node=self.dict(),
source_node_id=source_node_id,
)

View File

@@ -25,7 +25,7 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_dict: dict,
node: dict,
source_node_id: str,
progress_image: ProgressImage | None,
step: int,
@@ -36,7 +36,7 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
step=step,
@@ -48,16 +48,15 @@ class EventServiceBase:
self,
graph_execution_state_id: str,
result: dict,
invocation_dict: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has completed"""
print(result)
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
node=node,
source_node_id=source_node_id,
result=result,
),
@@ -66,7 +65,7 @@ class EventServiceBase:
def emit_invocation_error(
self,
graph_execution_state_id: str,
invocation_dict: dict,
node: dict,
source_node_id: str,
error: str,
) -> None:
@@ -75,21 +74,21 @@ class EventServiceBase:
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
node=node,
source_node_id=source_node_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_dict: dict, source_node_id: str
self, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation=invocation_dict,
node=node,
source_node_id=source_node_id,
),
)

View File

@@ -14,14 +14,12 @@ from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_pnginfo,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from invokeai.backend.image_util import PngWriter
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@@ -135,7 +133,6 @@ class DiskImageStorage(ImageStorageBase):
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
mode=img.mode,
invokeai=invokeai_metadata,
),
)
@@ -194,8 +191,13 @@ class DiskImageStorage(ImageStorageBase):
metadata: InvokeAIMetadata | None = None,
) -> Tuple[str, str, int]:
image_path = self.get_path(image_type, image_name)
pnginfo = build_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)

View File

@@ -20,6 +20,7 @@ class MetadataLatentsField(TypedDict):
latents_name: str
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
]
@@ -32,7 +33,10 @@ class InvokeAIMetadata(TypedDict, total=False):
node: Optional[NodeMetadata]
def build_pnginfo(metadata: InvokeAIMetadata | None) -> PngImagePlugin.PngInfo:
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
@@ -41,142 +45,6 @@ def build_pnginfo(metadata: InvokeAIMetadata | None) -> PngImagePlugin.PngInfo:
return pnginfo
def parse_image_field(image_field: dict[str, Any]) -> dict[str, Any] | None:
"""Parses an object as a MetadataImageField"""
# Must be a dict
if type(image_field) is not dict:
return None
# An ImageField must have both `image_name` and `image_type`
if not ("image_name" in image_field and "image_type" in image_field):
return None
# An ImageField's `image_type` must be one of the allowed values
if not is_image_type(image_field["image_type"]):
return None
# An ImageField's `image_name` must be a string
if type(image_field["image_name"]) is not str:
return None
parsed = {
"image_type": image_field["image_type"],
"image_name": image_field["image_name"],
}
return parsed
def parse_latents_field(latents_field: dict[str, Any]) -> dict[str, Any] | None:
"""Parses an object as a MetadataLatentsField"""
# Must be a dict
if type(latents_field) is not dict:
return None
# A LatentsField must have a `latents_name`
if not ("latents_name" in latents_field):
return None
# A LatentsField's `latents_name` must be a string
if type(latents_field["latents_name"]) is not str:
return None
parsed = {
"latents_name": latents_field["latents_name"],
}
return parsed
def parse_node_metadata(node_metadata: Any) -> NodeMetadata | None:
"""Parses node metadata, silently skipping invalid entries"""
# Must be a dict
if type(node_metadata) is not dict:
return None
# Must have attributes
if len(node_metadata.items()) == 0:
return None
parsed: dict[str, Any] = {}
# Conditionally build the parsed metadata, silently skipping invalid values
for name, value in node_metadata.items():
value_type = type(value)
# explicitly parse `id` and `type` as strings
if name == "id":
if value_type is not str:
continue
parsed[name] = value
continue
if name == "type":
if value_type is not str:
continue
parsed[name] = value
continue
# we only allow ImageField and ImageType as dicts
if value_type is dict:
if "image_name" in value or "image_type" in value:
# parse as an ImageField
image_field = parse_image_field(value)
if image_field is not None:
parsed[name] = image_field
continue
if "latents_name" in value:
# parse as a LatentsField
latents_field = parse_latents_field(value)
if latents_field is not None:
parsed[name] = latents_field
continue
# other allowed primitive values
if (
value_type is str
or value_type is int
or value_type is float
or value_type is bool
):
parsed[name] = value
continue
return parsed
def parse_invokeai_metadata(
invokeai_metadata: dict[str, Any]
) -> InvokeAIMetadata | None:
"""Parse the InvokeAI metadata format, silently skipping invalid entries"""
# Must be a dict
if type(invokeai_metadata) is not dict:
return None
# Must have attributes
if len(invokeai_metadata.items()) == 0:
return None
parsed: InvokeAIMetadata = {}
for name, value in invokeai_metadata.items():
if name == "session_id":
if type(value) is str:
parsed[name] = value
continue
if name == "node":
node_metadata = parse_node_metadata(value)
if node_metadata is not None:
parsed[name] = node_metadata
continue
return parsed
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
@@ -192,41 +60,36 @@ class MetadataServiceBase(ABC):
class PngMetadataService(MetadataServiceBase):
"""Handles loading metadata from images and parsing it."""
"""Handles loading and building metadata for images."""
# TODO: Support parsing old format metadata **hurk**
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image, key="invokeai") -> Any:
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
raw_metadata = image.info.get(key)
try:
info = image.info.get("invokeai")
# metadata should always be a dict
if type(raw_metadata) is not str:
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
loaded_metadata = json.loads(raw_metadata)
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def _parse_invokeai_metadata(
self,
metadata: Any,
) -> InvokeAIMetadata | None:
"""Parses an object as InvokeAI metadata."""
if type(metadata) is not dict:
return None
parsed_metadata = parse_invokeai_metadata(metadata)
return parsed_metadata
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
loaded_metadata = self._load_metadata(image)
parsed_metadata = self._parse_invokeai_metadata(loaded_metadata)
return parsed_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())

View File

@@ -49,7 +49,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
node=invocation.dict(),
source_node_id=source_node_id
)
@@ -79,7 +79,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
node=invocation.dict(),
source_node_id=source_node_id,
result=outputs.dict(),
)
@@ -104,7 +104,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_dict=invocation.dict(),
node=invocation.dict(),
source_node_id=source_node_id,
error=error,
)

View File

@@ -9,7 +9,7 @@ from ...backend.stable_diffusion import PipelineIntermediateState
def stable_diffusion_step_callback(
context: InvocationContext,
intermediate_state: PipelineIntermediateState,
invocation_dict: dict,
node: dict,
source_node_id: str,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
@@ -47,9 +47,9 @@ def stable_diffusion_step_callback(
context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
invocation_dict=invocation_dict,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
total_steps=invocation_dict["steps"],
total_steps=node["steps"],
)

View File

@@ -0,0 +1,119 @@
/**
* PARTIAL ZOD IMPLEMENTATION
*
* doesn't work well bc like most validators, zod is not built to skip invalid values.
* it mostly works but just seems clearer and simpler to manually parse for now.
*
* in the future it would be really nice if we could use zod for some things:
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
*/
// import { z } from 'zod';
// const zMetadataStringField = z.string();
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
// const zMetadataIntegerField = z.number().int();
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
// const zMetadataFloatField = z.number();
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
// const zMetadataBooleanField = z.boolean();
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
// const zMetadataImageField = z.object({
// image_type: z.union([
// z.literal('results'),
// z.literal('uploads'),
// z.literal('intermediates'),
// ]),
// image_name: z.string().min(1),
// });
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
// const zMetadataLatentsField = z.object({
// latents_name: z.string().min(1),
// });
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
// /**
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
// */
// const zAnyMetadataField = z.any().transform((val, ctx) => {
// // Grab the field name from the path
// const fieldName = String(ctx.path[ctx.path.length - 1]);
// // `id` and `type` must be strings if they exist
// if (['id', 'type'].includes(fieldName)) {
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
// if (reservedStringPropertyResult.success) {
// return reservedStringPropertyResult.data;
// }
// return;
// }
// // Parse the rest of the fields, only returning the data if the parsing is successful
// const stringFieldResult = zMetadataStringField.safeParse(val);
// if (stringFieldResult.success) {
// return stringFieldResult.data;
// }
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
// if (integerFieldResult.success) {
// return integerFieldResult.data;
// }
// const floatFieldResult = zMetadataFloatField.safeParse(val);
// if (floatFieldResult.success) {
// return floatFieldResult.data;
// }
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
// if (booleanFieldResult.success) {
// return booleanFieldResult.data;
// }
// const imageFieldResult = zMetadataImageField.safeParse(val);
// if (imageFieldResult.success) {
// return imageFieldResult.data;
// }
// const latentsFieldResult = zMetadataImageField.safeParse(val);
// if (latentsFieldResult.success) {
// return latentsFieldResult.data;
// }
// });
// /**
// * The node metadata schema.
// */
// const zNodeMetadata = z.object({
// session_id: z.string().min(1).optional(),
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
// });
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
// const zMetadata = z.object({
// invokeai: zNodeMetadata.optional(),
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
// });
// export type Metadata = z.infer<typeof zMetadata>;
// export const parseMetadata = (
// metadata: Record<string, any>
// ): Metadata | undefined => {
// const result = zMetadata.safeParse(metadata);
// if (!result.success) {
// console.log(result.error.issues);
// return;
// }
// return result.data;
// };
export default {};

View File

@@ -0,0 +1,169 @@
import { forEach, size } from 'lodash';
import { ImageField, LatentsField } from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
const NUMBER_TYPESTRING = '[object Number]';
const BOOLEAN_TYPESTRING = '[object Boolean]';
const ARRAY_TYPESTRING = '[object Array]';
const isObject = (obj: unknown): obj is Record<string | number, any> =>
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
const isString = (obj: unknown): obj is string =>
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
const isNumber = (obj: unknown): obj is number =>
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
const isBoolean = (obj: unknown): obj is boolean =>
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
const isArray = (obj: unknown): obj is Array<any> =>
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
const parseImageField = (imageField: unknown): ImageField | undefined => {
// Must be an object
if (!isObject(imageField)) {
return;
}
// An ImageField must have both `image_name` and `image_type`
if (!('image_name' in imageField && 'image_type' in imageField)) {
return;
}
// An ImageField's `image_type` must be one of the allowed values
if (
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
) {
return;
}
// An ImageField's `image_name` must be a string
if (typeof imageField.image_name !== 'string') {
return;
}
// Build a valid ImageField
return {
image_type: imageField.image_type,
image_name: imageField.image_name,
};
};
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
// Must be an object
if (!isObject(latentsField)) {
return;
}
// A LatentsField must have a `latents_name`
if (!('latents_name' in latentsField)) {
return;
}
// A LatentsField's `latents_name` must be a string
if (typeof latentsField.latents_name !== 'string') {
return;
}
// Build a valid LatentsField
return {
latents_name: latentsField.latents_name,
};
};
type NodeMetadata = {
[key: string]: string | number | boolean | ImageField | LatentsField;
};
type InvokeAIMetadata = {
session_id?: string;
node?: NodeMetadata;
};
export const parseNodeMetadata = (
nodeMetadata: Record<string | number, any>
): NodeMetadata | undefined => {
if (!isObject(nodeMetadata)) {
return;
}
const parsed: NodeMetadata = {};
forEach(nodeMetadata, (nodeItem, nodeKey) => {
// `id` and `type` must be strings if they are present
if (['id', 'type'].includes(nodeKey)) {
if (isString(nodeItem)) {
parsed[nodeKey] = nodeItem;
}
return;
}
// the only valid object types are ImageField and LatentsField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
if (imageField) {
parsed[nodeKey] = imageField;
}
return;
}
if ('latents_name' in nodeItem) {
const latentsField = parseLatentsField(nodeItem);
if (latentsField) {
parsed[nodeKey] = latentsField;
}
return;
}
}
// otherwise we accept any string, number or boolean
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
parsed[nodeKey] = nodeItem;
return;
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};
export const parseInvokeAIMetadata = (
metadata: Record<string | number, any> | undefined
): InvokeAIMetadata | undefined => {
if (metadata === undefined) {
return;
}
if (!isObject(metadata)) {
return;
}
const parsed: InvokeAIMetadata = {};
forEach(metadata, (item, key) => {
if (key === 'session_id' && isString(item)) {
parsed['session_id'] = item;
}
if (key === 'node' && isObject(item)) {
const nodeMetadata = parseNodeMetadata(item);
if (nodeMetadata) {
parsed['node'] = nodeMetadata;
}
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};

View File

@@ -192,21 +192,21 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem
label="Seed"
value={node.seed}
onClick={() => dispatch(setSeed(node.seed))}
onClick={() => dispatch(setSeed(Number(node.seed)))}
/>
)}
{node.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={node.threshold}
onClick={() => dispatch(setThreshold(node.threshold))}
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
/>
)}
{node.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={node.perlin}
onClick={() => dispatch(setPerlin(node.perlin))}
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
/>
)}
{node.scheduler && (
@@ -220,14 +220,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem
label="Steps"
value={node.steps}
onClick={() => dispatch(setSteps(node.steps))}
onClick={() => dispatch(setSteps(Number(node.steps)))}
/>
)}
{node.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={node.cfg_scale}
onClick={() => dispatch(setCfgScale(node.cfg_scale))}
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
/>
)}
{node.variations && node.variations.length > 0 && (
@@ -257,14 +257,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem
label="Width"
value={node.width}
onClick={() => dispatch(setWidth(node.width))}
onClick={() => dispatch(setWidth(Number(node.width)))}
/>
)}
{node.height && (
<MetadataItem
label="Height"
value={node.height}
onClick={() => dispatch(setHeight(node.height))}
onClick={() => dispatch(setHeight(Number(node.height)))}
/>
)}
{/* {init_image_path && (
@@ -279,7 +279,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
<MetadataItem
label="Image to image strength"
value={node.strength}
onClick={() => dispatch(setImg2imgStrength(node.strength))}
onClick={() =>
dispatch(setImg2imgStrength(Number(node.strength)))
}
/>
)}
{node.fit && (

View File

@@ -36,10 +36,6 @@ type AdditionalResultsState = {
nextPage: number; // the next page to request
};
// export type ResultsState = ReturnType<
// typeof resultsAdapter.getInitialState<AdditionalResultsState>
// >;
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
@@ -97,7 +93,7 @@ const resultsSlice = createSlice({
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, invocation, graph_execution_state_id } = data;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
@@ -115,10 +111,9 @@ const resultsSlice = createSlice({
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
mode: result.mode,
invokeai: {
session_id: graph_execution_state_id,
invocation,
...(node ? { node } : {}),
},
},
};

View File

@@ -66,16 +66,8 @@ const uploadsSlice = createSlice({
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const { image_name, image_url, image_type, metadata, thumbnail_url } =
response;
const uploadedImage: Image = {
name: image_name,
url: image_url,
thumbnail: thumbnail_url,
type: 'uploads',
metadata,
};
const uploadedImage = deserializeImageResponse(response);
uploadsAdapter.addOne(state, uploadedImage);
});

View File

@@ -25,9 +25,9 @@ export type { GraphInvocation } from './models/GraphInvocation';
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
export type { HTTPValidationError } from './models/HTTPValidationError';
export type { ImageField } from './models/ImageField';
export type { ImageMetadata } from './models/ImageMetadata';
export type { ImageOutput } from './models/ImageOutput';
export type { ImageResponse } from './models/ImageResponse';
export type { ImageResponseMetadata } from './models/ImageResponseMetadata';
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
export type { ImageType } from './models/ImageType';
export type { InpaintInvocation } from './models/InpaintInvocation';
@@ -45,9 +45,10 @@ export type { LerpInvocation } from './models/LerpInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { MetadataImageField } from './models/MetadataImageField';
export type { MetadataLatentsField } from './models/MetadataLatentsField';
export type { ModelsList } from './models/ModelsList';
export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NodeMetadata } from './models/NodeMetadata';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
@@ -85,9 +86,9 @@ export { $GraphInvocation } from './schemas/$GraphInvocation';
export { $GraphInvocationOutput } from './schemas/$GraphInvocationOutput';
export { $HTTPValidationError } from './schemas/$HTTPValidationError';
export { $ImageField } from './schemas/$ImageField';
export { $ImageMetadata } from './schemas/$ImageMetadata';
export { $ImageOutput } from './schemas/$ImageOutput';
export { $ImageResponse } from './schemas/$ImageResponse';
export { $ImageResponseMetadata } from './schemas/$ImageResponseMetadata';
export { $ImageToImageInvocation } from './schemas/$ImageToImageInvocation';
export { $ImageType } from './schemas/$ImageType';
export { $InpaintInvocation } from './schemas/$InpaintInvocation';
@@ -105,9 +106,10 @@ export { $LerpInvocation } from './schemas/$LerpInvocation';
export { $LoadImageInvocation } from './schemas/$LoadImageInvocation';
export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
export { $MaskOutput } from './schemas/$MaskOutput';
export { $MetadataImageField } from './schemas/$MetadataImageField';
export { $MetadataLatentsField } from './schemas/$MetadataLatentsField';
export { $ModelsList } from './schemas/$ModelsList';
export { $MultiplyInvocation } from './schemas/$MultiplyInvocation';
export { $NodeMetadata } from './schemas/$NodeMetadata';
export { $NoiseInvocation } from './schemas/$NoiseInvocation';
export { $NoiseOutput } from './schemas/$NoiseOutput';
export { $PaginatedResults_GraphExecutionState_ } from './schemas/$PaginatedResults_GraphExecutionState_';

View File

@@ -21,9 +21,5 @@ export type ImageOutput = {
* The height of the image in pixels
*/
height: number;
/**
* The image mode (ie pixel format)
*/
mode: string;
};

View File

@@ -2,7 +2,7 @@
/* tslint:disable */
/* eslint-disable */
import type { ImageMetadata } from './ImageMetadata';
import type { ImageResponseMetadata } from './ImageResponseMetadata';
import type { ImageType } from './ImageType';
/**
@@ -28,6 +28,6 @@ export type ImageResponse = {
/**
* The image's metadata
*/
metadata: ImageMetadata;
metadata: ImageResponseMetadata;
};

View File

@@ -7,7 +7,7 @@ import type { InvokeAIMetadata } from './InvokeAIMetadata';
/**
* An image's metadata. Used only in HTTP responses.
*/
export type ImageMetadata = {
export type ImageResponseMetadata = {
/**
* The creation timestamp of the image
*/
@@ -20,10 +20,6 @@ export type ImageMetadata = {
* The height of the image in pixels
*/
height: number;
/**
* The color mode of the image
*/
mode: string;
/**
* The image's InvokeAI-specific metadata
*/

View File

@@ -2,16 +2,11 @@
/* tslint:disable */
/* eslint-disable */
import type { NodeMetadata } from './NodeMetadata';
import type { MetadataImageField } from './MetadataImageField';
import type { MetadataLatentsField } from './MetadataLatentsField';
export type InvokeAIMetadata = {
/**
* The session in which this image was created
*/
session_id?: string;
/**
* The node that created this image
*/
node?: NodeMetadata;
node?: Record<string, (string | number | boolean | MetadataImageField | MetadataLatentsField)>;
};

View File

@@ -9,6 +9,6 @@ export type LatentsField = {
/**
* The name of the latents
*/
latents_name?: string;
latents_name: string;
};

View File

@@ -0,0 +1,11 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageType } from './ImageType';
export type MetadataImageField = {
image_type: ImageType;
image_name: string;
};

View File

@@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type MetadataLatentsField = {
latents_name: string;
};

View File

@@ -1,10 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Node metadata model, used for validation of metadata.
*/
export type NodeMetadata = {
};

View File

@@ -24,7 +24,7 @@ export type RandomRangeInvocation = {
*/
size?: number;
/**
* The seed for the RNG, provide None or -1 for random
* The seed for the RNG
*/
seed?: number;
};

View File

@@ -26,10 +26,5 @@ export const $ImageOutput = {
description: `The height of the image in pixels`,
isRequired: true,
},
mode: {
type: 'string',
description: `The image mode (ie pixel format)`,
isRequired: true,
},
},
} as const;

View File

@@ -31,7 +31,7 @@ export const $ImageResponse = {
type: 'all-of',
description: `The image's metadata`,
contains: [{
type: 'ImageMetadata',
type: 'ImageResponseMetadata',
}],
isRequired: true,
},

View File

@@ -1,7 +1,7 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ImageMetadata = {
export const $ImageResponseMetadata = {
description: `An image's metadata. Used only in HTTP responses.`,
properties: {
created: {
@@ -19,11 +19,6 @@ export const $ImageMetadata = {
description: `The height of the image in pixels`,
isRequired: true,
},
mode: {
type: 'string',
description: `The color mode of the image`,
isRequired: true,
},
invokeai: {
type: 'all-of',
description: `The image's InvokeAI-specific metadata`,

View File

@@ -5,14 +5,25 @@ export const $InvokeAIMetadata = {
properties: {
session_id: {
type: 'string',
description: `The session in which this image was created`,
},
node: {
type: 'all-of',
description: `The node that created this image`,
contains: [{
type: 'NodeMetadata',
}],
type: 'dictionary',
contains: {
type: 'any-of',
contains: [{
type: 'string',
}, {
type: 'number',
}, {
type: 'number',
}, {
type: 'boolean',
}, {
type: 'MetadataImageField',
}, {
type: 'MetadataLatentsField',
}],
},
},
},
} as const;

View File

@@ -7,6 +7,7 @@ export const $LatentsField = {
latents_name: {
type: 'string',
description: `The name of the latents`,
isRequired: true,
},
},
} as const;

View File

@@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $MetadataImageField = {
properties: {
image_type: {
type: 'ImageType',
isRequired: true,
},
image_name: {
type: 'string',
isRequired: true,
},
},
} as const;

View File

@@ -0,0 +1,11 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $MetadataLatentsField = {
properties: {
latents_name: {
type: 'string',
isRequired: true,
},
},
} as const;

View File

@@ -1,8 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $NodeMetadata = {
description: `Node metadata model, used for validation of metadata.`,
properties: {
},
} as const;

View File

@@ -26,7 +26,7 @@ export const $RandomRangeInvocation = {
},
seed: {
type: 'number',
description: `The seed for the RNG, provide None or -1 for random`,
description: `The seed for the RNG`,
maximum: 2147483647,
},
},

View File

@@ -177,25 +177,25 @@ export const socketMiddleware = () => {
// Set up listeners for the present subscription
socket.on('invocation_started', (data) => {
if (shouldHandleEvent(data.invocation.id)) {
if (shouldHandleEvent(data.node.id)) {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
}
});
socket.on('generator_progress', (data) => {
if (shouldHandleEvent(data.invocation.id)) {
if (shouldHandleEvent(data.node.id)) {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_error', (data) => {
if (shouldHandleEvent(data.invocation.id)) {
if (shouldHandleEvent(data.node.id)) {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_complete', (data) => {
if (shouldHandleEvent(data.invocation.id)) {
if (shouldHandleEvent(data.node.id)) {
const sessionId = data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;

View File

@@ -30,7 +30,8 @@ export type AnyResult = GraphExecutionState['results'][string];
*/
export type GeneratorProgressEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
node: AnyInvocation;
source_node_id: string;
progress_image?: ProgressImage;
step: number;
total_steps: number;
@@ -45,7 +46,8 @@ export type GeneratorProgressEvent = {
*/
export type InvocationCompleteEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
node: AnyInvocation;
source_node_id: string;
result: AnyResult;
};
@@ -56,7 +58,8 @@ export type InvocationCompleteEvent = {
*/
export type InvocationErrorEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
node: AnyInvocation;
source_node_id: string;
error: string;
};
@@ -67,7 +70,8 @@ export type InvocationErrorEvent = {
*/
export type InvocationStartedEvent = {
graph_execution_state_id: string;
invocation: AnyInvocation;
node: AnyInvocation;
source_node_id: string;
};
/**

View File

@@ -43,6 +43,8 @@ type SessionCreatedArg = {
export const sessionCreated = createAppAsyncThunk(
'api/sessionCreated',
async (arg: SessionCreatedArg, { dispatch, getState }) => {
console.log('Session created, graph: ', arg.graph);
const response = await SessionsService.createSession({
requestBody: arg.graph,
});

View File

@@ -1,4 +1,5 @@
import { Image } from 'app/invokeai';
import { parseInvokeAIMetadata } from 'common/util/parseMetadata';
import { ImageResponse } from 'services/api';
/**
@@ -11,12 +12,18 @@ export const deserializeImageResponse = (
imageResponse;
// TODO: parse metadata - just leaving it as-is for now
const { invokeai, ...rest } = metadata;
const parsedMetadata = parseInvokeAIMetadata(invokeai);
return {
name: image_name,
type: image_type,
url: image_url,
thumbnail: thumbnail_url,
metadata,
metadata: {
...rest,
...(invokeai ? { invokeai: parsedMetadata } : {}),
},
};
};

View File

@@ -0,0 +1,174 @@
export default {};
// python metadata parsing tests to rebuild
// # def test_is_good_metadata_unchanged():
// # parsed_metadata = metadata_service._parse_invokeai_metadata(valid_metadata)
// # expected = deepcopy(valid_metadata)
// # assert expected == parsed_metadata
// # def test_can_parse_missing_session_id():
// # metadata_missing_session_id = deepcopy(valid_metadata)
// # del metadata_missing_session_id["session_id"]
// # expected = deepcopy(valid_metadata)
// # del expected["session_id"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_missing_session_id
// # )
// # assert metadata_missing_session_id == parsed_metadata
// # def test_can_parse_invalid_session_id():
// # metadata_invalid_session_id = deepcopy(valid_metadata)
// # metadata_invalid_session_id["session_id"] = 123
// # expected = deepcopy(valid_metadata)
// # del expected["session_id"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_invalid_session_id
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_missing_node():
// # metadata_missing_node = deepcopy(valid_metadata)
// # del metadata_missing_node["node"]
// # expected = deepcopy(valid_metadata)
// # del expected["node"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_missing_node)
// # assert expected == parsed_metadata
// # def test_can_parse_invalid_node():
// # metadata_invalid_node = deepcopy(valid_metadata)
// # metadata_invalid_node["node"] = 123
// # expected = deepcopy(valid_metadata)
// # del expected["node"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_invalid_node)
// # assert expected == parsed_metadata
// # def test_can_parse_missing_node_id():
// # metadata_missing_node_id = deepcopy(valid_metadata)
// # del metadata_missing_node_id["node"]["id"]
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["id"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_missing_node_id
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_invalid_node_id():
// # metadata_invalid_node_id = deepcopy(valid_metadata)
// # metadata_invalid_node_id["node"]["id"] = 123
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["id"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_invalid_node_id
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_missing_node_type():
// # metadata_missing_node_type = deepcopy(valid_metadata)
// # del metadata_missing_node_type["node"]["type"]
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["type"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_missing_node_type
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_invalid_node_type():
// # metadata_invalid_node_type = deepcopy(valid_metadata)
// # metadata_invalid_node_type["node"]["type"] = 123
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["type"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_invalid_node_type
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_no_node_attrs():
// # metadata_no_node_attrs = deepcopy(valid_metadata)
// # metadata_no_node_attrs["node"] = {}
// # expected = deepcopy(valid_metadata)
// # del expected["node"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_no_node_attrs)
// # assert expected == parsed_metadata
// # def test_can_parse_array_attr():
// # metadata_array_attr = deepcopy(valid_metadata)
// # metadata_array_attr["node"]["seed"] = [1, 2, 3]
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["seed"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_array_attr)
// # assert expected == parsed_metadata
// # def test_can_parse_invalid_dict_attr():
// # metadata_invalid_dict_attr = deepcopy(valid_metadata)
// # metadata_invalid_dict_attr["node"]["seed"] = {"a": 1}
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["seed"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_invalid_dict_attr
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_missing_image_field_image_type():
// # metadata_missing_image_field_image_type = deepcopy(valid_metadata)
// # del metadata_missing_image_field_image_type["node"]["image"]["image_type"]
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["image"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_missing_image_field_image_type
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_invalid_image_field_image_type():
// # metadata_invalid_image_field_image_type = deepcopy(valid_metadata)
// # metadata_invalid_image_field_image_type["node"]["image"][
// # "image_type"
// # ] = "bad image type"
// # expected = deepcopy(valid_metadata)
// # del expected["node"]["image"]
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_invalid_image_field_image_type
// # )
// # assert expected == parsed_metadata
// # def test_can_parse_invalid_latents_field_latents_name():
// # metadata_invalid_latents_field_latents_name = deepcopy(valid_metadata)
// # metadata_invalid_latents_field_latents_name["node"]["latents"] = {
// # "latents_name": 123
// # }
// # expected = deepcopy(valid_metadata)
// # parsed_metadata = metadata_service._parse_invokeai_metadata(
// # metadata_invalid_latents_field_latents_name
// # )
// # assert expected == parsed_metadata

View File

@@ -1,7 +1,6 @@
import json
import os
from copy import deepcopy
from PIL import Image, PngImagePlugin
from invokeai.app.invocations.generate import TextToImageInvocation
@@ -17,7 +16,6 @@ valid_metadata = {
"steps": 30,
"width": 512,
"height": 512,
"image": {"image_type": "results", "image_name": "1"},
"cfg_scale": 7.5,
"scheduler": "k_lms",
"seamless": False,
@@ -29,192 +27,6 @@ valid_metadata = {
metadata_service = PngMetadataService()
def test_is_good_metadata_unchanged():
parsed_metadata = metadata_service._parse_invokeai_metadata(valid_metadata)
expected = deepcopy(valid_metadata)
assert expected == parsed_metadata
def test_can_parse_missing_session_id():
metadata_missing_session_id = deepcopy(valid_metadata)
del metadata_missing_session_id["session_id"]
expected = deepcopy(valid_metadata)
del expected["session_id"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_missing_session_id
)
assert metadata_missing_session_id == parsed_metadata
def test_can_parse_invalid_session_id():
metadata_invalid_session_id = deepcopy(valid_metadata)
metadata_invalid_session_id["session_id"] = 123
expected = deepcopy(valid_metadata)
del expected["session_id"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_invalid_session_id
)
assert expected == parsed_metadata
def test_can_parse_missing_node():
metadata_missing_node = deepcopy(valid_metadata)
del metadata_missing_node["node"]
expected = deepcopy(valid_metadata)
del expected["node"]
parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_missing_node)
assert expected == parsed_metadata
def test_can_parse_invalid_node():
metadata_invalid_node = deepcopy(valid_metadata)
metadata_invalid_node["node"] = 123
expected = deepcopy(valid_metadata)
del expected["node"]
parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_invalid_node)
assert expected == parsed_metadata
def test_can_parse_missing_node_id():
metadata_missing_node_id = deepcopy(valid_metadata)
del metadata_missing_node_id["node"]["id"]
expected = deepcopy(valid_metadata)
del expected["node"]["id"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_missing_node_id
)
assert expected == parsed_metadata
def test_can_parse_invalid_node_id():
metadata_invalid_node_id = deepcopy(valid_metadata)
metadata_invalid_node_id["node"]["id"] = 123
expected = deepcopy(valid_metadata)
del expected["node"]["id"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_invalid_node_id
)
assert expected == parsed_metadata
def test_can_parse_missing_node_type():
metadata_missing_node_type = deepcopy(valid_metadata)
del metadata_missing_node_type["node"]["type"]
expected = deepcopy(valid_metadata)
del expected["node"]["type"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_missing_node_type
)
assert expected == parsed_metadata
def test_can_parse_invalid_node_type():
metadata_invalid_node_type = deepcopy(valid_metadata)
metadata_invalid_node_type["node"]["type"] = 123
expected = deepcopy(valid_metadata)
del expected["node"]["type"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_invalid_node_type
)
assert expected == parsed_metadata
def test_can_parse_no_node_attrs():
metadata_no_node_attrs = deepcopy(valid_metadata)
metadata_no_node_attrs["node"] = {}
expected = deepcopy(valid_metadata)
del expected["node"]
parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_no_node_attrs)
assert expected == parsed_metadata
def test_can_parse_array_attr():
metadata_array_attr = deepcopy(valid_metadata)
metadata_array_attr["node"]["seed"] = [1, 2, 3]
expected = deepcopy(valid_metadata)
del expected["node"]["seed"]
parsed_metadata = metadata_service._parse_invokeai_metadata(metadata_array_attr)
assert expected == parsed_metadata
def test_can_parse_invalid_dict_attr():
metadata_invalid_dict_attr = deepcopy(valid_metadata)
metadata_invalid_dict_attr["node"]["seed"] = {"a": 1}
expected = deepcopy(valid_metadata)
del expected["node"]["seed"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_invalid_dict_attr
)
assert expected == parsed_metadata
def test_can_parse_missing_image_field_image_type():
metadata_missing_image_field_image_type = deepcopy(valid_metadata)
del metadata_missing_image_field_image_type["node"]["image"]["image_type"]
expected = deepcopy(valid_metadata)
del expected["node"]["image"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_missing_image_field_image_type
)
assert expected == parsed_metadata
def test_can_parse_invalid_image_field_image_type():
metadata_invalid_image_field_image_type = deepcopy(valid_metadata)
metadata_invalid_image_field_image_type["node"]["image"][
"image_type"
] = "bad image type"
expected = deepcopy(valid_metadata)
del expected["node"]["image"]
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_invalid_image_field_image_type
)
assert expected == parsed_metadata
def test_can_parse_invalid_latents_field_latents_name():
metadata_invalid_latents_field_latents_name = deepcopy(valid_metadata)
metadata_invalid_latents_field_latents_name["node"]["latents"] = {
"latents_name": 123
}
expected = deepcopy(valid_metadata)
parsed_metadata = metadata_service._parse_invokeai_metadata(
metadata_invalid_latents_field_latents_name
)
assert expected == parsed_metadata
def test_can_load_and_parse_invokeai_metadata(tmp_path):
raw_metadata = {"session_id": "123", "node": {"id": "456", "type": "test_type"}}
@@ -228,49 +40,16 @@ def test_can_load_and_parse_invokeai_metadata(tmp_path):
image = Image.open(temp_image_path)
loaded_metadata = metadata_service._load_metadata(image)
parsed_metadata = metadata_service._parse_invokeai_metadata(loaded_metadata)
loaded_and_parsed_metadata = metadata_service.get_metadata(image)
loaded_metadata = metadata_service.get_metadata(image)
assert raw_metadata == loaded_metadata
assert raw_metadata == parsed_metadata
assert raw_metadata == loaded_and_parsed_metadata
assert loaded_metadata is not None
assert raw_metadata == loaded_metadata["invokeai"]
def test_can_build_invokeai_metadata():
session_id = "123"
invocation = TextToImageInvocation(
id="456",
prompt="test",
seed=1,
steps=10,
width=512,
height=512,
cfg_scale=7.5,
scheduler="k_lms",
seamless=False,
model="test_mode",
progress_images=True,
)
session_id = valid_metadata["session_id"]
node = TextToImageInvocation(**valid_metadata["node"])
metadata = metadata_service.build_metadata(session_id=session_id, node=invocation)
metadata = metadata_service.build_metadata(session_id=session_id, node=node)
expected_metadata_dict = {
"session_id": "123",
"node": {
"id": "456",
"type": "txt2img",
"prompt": "test",
"seed": 1,
"steps": 10,
"width": 512,
"height": 512,
"cfg_scale": 7.5,
"scheduler": "k_lms",
"seamless": False,
"model": "test_mode",
"progress_images": True,
},
}
assert expected_metadata_dict == metadata
assert valid_metadata == metadata