mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(nodes): make MetadataInvocation.items polymorphic
This commit is contained in:
@@ -178,6 +178,8 @@ class UIType(str, Enum):
|
||||
BoardField = "BoardField"
|
||||
Any = "Any"
|
||||
MetadataItem = "MetadataItem"
|
||||
MetadataItemCollection = "MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "MetadataItemPolymorphic"
|
||||
MetadataDict = "MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -217,7 +217,10 @@ class MetadataDictOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MetadataInvocation(BaseInvocation):
|
||||
items: list[MetadataItem] = InputField(description="List of metadata items")
|
||||
items: Union[list[MetadataItem], MetadataItem] = InputField(description="List of metadata items")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataDictOutput:
|
||||
if isinstance(self.items, MetadataItem):
|
||||
return MetadataDictOutput(metadata_dict=(MetadataDict(data={self.items.label: self.items.value})))
|
||||
|
||||
return MetadataDictOutput(metadata_dict=(MetadataDict(data={item.label: item.value for item in self.items})))
|
||||
|
||||
@@ -44,6 +44,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
'MetadataItemPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES: FieldType[] = [
|
||||
@@ -89,6 +90,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
|
||||
ConditioningField: 'ConditioningPolymorphic',
|
||||
ControlField: 'ControlPolymorphic',
|
||||
ColorField: 'ColorPolymorphic',
|
||||
MetadataItem: 'MetadataItemPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
@@ -101,6 +103,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
|
||||
ConditioningPolymorphic: 'ConditioningField',
|
||||
ControlPolymorphic: 'ControlField',
|
||||
ColorPolymorphic: 'ColorField',
|
||||
MetadataItemPolymorphic: 'MetadataItem',
|
||||
};
|
||||
|
||||
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
|
||||
@@ -150,9 +153,15 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
},
|
||||
MetadataItemCollection: {
|
||||
color: 'gray.500',
|
||||
description: 'MetadataItemCollection field type is accepted.',
|
||||
description: 'Any field type is accepted.',
|
||||
title: 'Metadata Item Collection',
|
||||
},
|
||||
MetadataItemPolymorphic: {
|
||||
color: 'gray.500',
|
||||
description:
|
||||
'MetadataItem or MetadataItemCollection field types are accepted.',
|
||||
title: 'Metadata Item Polymorphic',
|
||||
},
|
||||
boolean: {
|
||||
color: 'green.500',
|
||||
description: t('nodes.booleanDescription'),
|
||||
|
||||
@@ -111,6 +111,7 @@ export const zFieldType = z.enum([
|
||||
'MetadataDict',
|
||||
'MetadataItem',
|
||||
'MetadataItemCollection',
|
||||
'MetadataItemPolymorphic',
|
||||
'ONNXModelField',
|
||||
'Scheduler',
|
||||
'SDXLMainModelField',
|
||||
@@ -634,6 +635,15 @@ export type MetadataItemCollectionInputFieldValue = z.infer<
|
||||
typeof zMetadataItemCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zMetadataItemPolymorphicInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('MetadataItemPolymorphic'),
|
||||
value: z.union([zMetadataItem, z.array(zMetadataItem)]).optional(),
|
||||
});
|
||||
export type MetadataItemPolymorphicInputFieldValue = z.infer<
|
||||
typeof zMetadataItemPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zMetadataDict = z.record(z.any());
|
||||
export type MetadataDict = z.infer<typeof zMetadataDict>;
|
||||
|
||||
@@ -736,6 +746,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zVaeModelInputFieldValue,
|
||||
zMetadataItemInputFieldValue,
|
||||
zMetadataItemCollectionInputFieldValue,
|
||||
zMetadataItemPolymorphicInputFieldValue,
|
||||
zMetadataDictInputFieldValue,
|
||||
]);
|
||||
|
||||
@@ -1035,6 +1046,13 @@ export type MetadataItemCollectionInputFieldTemplate =
|
||||
type: 'MetadataItemCollection';
|
||||
};
|
||||
|
||||
export type MetadataItemPolymorphicInputFieldTemplate = Omit<
|
||||
MetadataItemInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'MetadataItemPolymorphic';
|
||||
};
|
||||
|
||||
export type MetadataDictInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'MetadataDict';
|
||||
@@ -1094,7 +1112,8 @@ export type InputFieldTemplate =
|
||||
| VaeModelInputFieldTemplate
|
||||
| MetadataItemInputFieldTemplate
|
||||
| MetadataItemCollectionInputFieldTemplate
|
||||
| MetadataDictInputFieldTemplate;
|
||||
| MetadataDictInputFieldTemplate
|
||||
| MetadataItemPolymorphicInputFieldTemplate;
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
|
||||
@@ -56,6 +56,7 @@ import {
|
||||
MetadataDictInputFieldTemplate,
|
||||
MetadataItemCollectionInputFieldTemplate,
|
||||
MetadataItemInputFieldTemplate,
|
||||
MetadataItemPolymorphicInputFieldTemplate,
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SchedulerInputFieldTemplate,
|
||||
@@ -771,6 +772,18 @@ const buildMetadataItemCollectionInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataItemPolymorphicInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataItemPolymorphicInputFieldTemplate => {
|
||||
const template: MetadataItemPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'MetadataItemPolymorphic',
|
||||
default: undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMetadataDictInputFieldTemplate = ({
|
||||
baseField,
|
||||
}: BuildInputFieldArg): MetadataDictInputFieldTemplate => {
|
||||
@@ -958,6 +971,7 @@ const TEMPLATE_BUILDER_MAP: {
|
||||
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||
MetadataItem: buildMetadataItemInputFieldTemplate,
|
||||
MetadataItemCollection: buildMetadataItemCollectionInputFieldTemplate,
|
||||
MetadataItemPolymorphic: buildMetadataItemPolymorphicInputFieldTemplate,
|
||||
MetadataDict: buildMetadataDictInputFieldTemplate,
|
||||
MainModelField: buildMainModelInputFieldTemplate,
|
||||
Scheduler: buildSchedulerInputFieldTemplate,
|
||||
|
||||
@@ -39,6 +39,7 @@ const FIELD_VALUE_FALLBACK_MAP: {
|
||||
LatentsPolymorphic: undefined,
|
||||
MetadataItem: undefined,
|
||||
MetadataItemCollection: [],
|
||||
MetadataItemPolymorphic: undefined,
|
||||
MetadataDict: {},
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
|
||||
@@ -9634,7 +9634,7 @@ export type components = {
|
||||
*/
|
||||
StableDiffusionOnnxModelFormat: 'olive' | 'onnx';
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user