diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 6bf936a55c..bf85f4f468 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -9,7 +9,7 @@ import type { UnrecallableMetadataHandler, } from 'features/metadata/parsing'; import { - MetadataHanders, + MetadataHandlers, useCollectionMetadataDatum, useSingleMetadataDatum, useUnrecallableMetadataDatum, @@ -30,33 +30,33 @@ const ImageMetadataActions = (props: Props) => { return ( - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + ); }; @@ -84,7 +84,7 @@ const UnrecallableMetadataParsed = typedMemo( return ( - + ); @@ -127,7 +127,7 @@ const SingleMetadataParsed = typedMemo( onClick={onClick} /> - + @@ -148,7 +148,7 @@ const CollectionMetadataDatum = typedMemo( return ( <> {data.value.map((value, i) => ( - + ))} ); @@ -158,17 +158,7 @@ const CollectionMetadataDatum = typedMemo( CollectionMetadataDatum.displayName = 'CollectionMetadataDatum'; const CollectionMetadataParsed = typedMemo( - ({ - value, - i, - data, - handler, - }: { - value: T[number]; - i: number; - data: ParsedSuccessData; - handler: CollectionMetadataHandler; - }) => { + ({ value, handler }: { value: T[number]; handler: CollectionMetadataHandler }) => { const store = useAppStore(); const { LabelComponent, ValueComponent } = handler; @@ -187,7 +177,7 @@ const CollectionMetadataParsed = typedMemo( onClick={onClick} /> - + diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index e3b45fd7f8..f42601078f 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -13,6 +13,7 @@ import { positivePrompt2Changed, positivePromptChanged, refinerModelChanged, + selectBase, setCfgRescaleMultiplier, setCfgScale, setGuidance, @@ -28,6 +29,7 @@ import { setSeamlessYAxis, setSeed, setSteps, + shouldConcatPromptsChanged, vaeSelected, } from 'features/controlLayers/store/paramsSlice'; import type { LoRA } from 'features/controlLayers/store/types'; @@ -77,7 +79,9 @@ import { zParameterSteps, zParameterStrength, } from 'features/parameters/types/parameterSchemas'; -import type { ComponentType } from 'react'; +import { toast } from 'features/toast/toast'; +import { t } from 'i18next'; +import type { ComponentType, ReactNode } from 'react'; import { useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { modelsApi } from 'services/api/endpoints/models'; @@ -94,15 +98,6 @@ const MetadataLabel = ({ i18nKey }: { i18nKey: string }) => { ); }; -const MetadataLabelWithCount = ({ i18nKey, i }: { i18nKey: string; i: number; values: T }) => { - const { t } = useTranslation(); - return ( - - {`${t(i18nKey)} ${i + 1}:`} - - ); -}; - const MetadataPrimitiveValue = ({ value }: { value: string | number | boolean | null | undefined }) => { return {value}; }; @@ -158,53 +153,60 @@ const buildParsedErrorData = (error: Error): ParsedErrorData => ({ export type Data = UnparsedData | ParsedSuccessData | ParsedErrorData; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -type SingleMetadataLabelProps = { - value: T; -}; +const SingleMetadataKey = Symbol('SingleMetadataKey'); type SingleMetadataValueProps = { value: T; }; export type SingleMetadataHandler = { + [SingleMetadataKey]: true; type: string; parse: (metadata: unknown, store: AppStore) => Promise; recall: (value: T, store: AppStore) => void; - LabelComponent: ComponentType>; + LabelComponent: ComponentType; ValueComponent: ComponentType>; }; -type CollectionMetadataLabelProps = { - values: T; - i: number; -}; +const CollectionMetadataKey = Symbol('CollectionMetadataKey'); type CollectionMetadataValueProps = { value: T[number]; }; export type CollectionMetadataHandler = { + [CollectionMetadataKey]: true; type: string; parse: (metadata: unknown, store: AppStore) => Promise; - recallAll: (values: T, store: AppStore) => void; + recall: (values: T, store: AppStore) => void; recallOne: (value: T[number], store: AppStore) => void; - LabelComponent: ComponentType>; + LabelComponent: ComponentType; ValueComponent: ComponentType>; }; -/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ -type UnrecallableMetadataLabelProps = { - value: T; -}; +const UnrecallableMetadataKey = Symbol('UnrecallableMetadataKey'); type UnrecallableMetadataValueProps = { value: T; }; export type UnrecallableMetadataHandler = { + [UnrecallableMetadataKey]: true; type: string; parse: (metadata: unknown, store: AppStore) => Promise; - LabelComponent: ComponentType>; + LabelComponent: ComponentType; ValueComponent: ComponentType>; }; +const isSingleMetadataHandler = ( + handler: SingleMetadataHandler | CollectionMetadataHandler | UnrecallableMetadataHandler +): handler is SingleMetadataHandler => { + return SingleMetadataKey in handler && handler[SingleMetadataKey] === true; +}; + +const isCollectionMetadataHandler = ( + handler: SingleMetadataHandler | CollectionMetadataHandler | UnrecallableMetadataHandler +): handler is CollectionMetadataHandler => { + return CollectionMetadataKey in handler && handler[CollectionMetadataKey] === true; +}; + //#region Created By const CreatedBy: UnrecallableMetadataHandler = { + [UnrecallableMetadataKey]: true, type: 'CreatedBy', parse: (metadata, _store) => { const raw = getProperty(metadata, 'created_by'); @@ -212,12 +214,13 @@ const CreatedBy: UnrecallableMetadataHandler = { return Promise.resolve(parsed); }, LabelComponent: () => , - ValueComponent: ({ value }: UnrecallableMetadataLabelProps) => , + ValueComponent: ({ value }: UnrecallableMetadataValueProps) => , }; //#endregion Created By //#region Generation Mode const GenerationMode: UnrecallableMetadataHandler = { + [UnrecallableMetadataKey]: true, type: 'GenerationMode', parse: (metadata, _store) => { const raw = getProperty(metadata, 'generation_mode'); @@ -225,12 +228,13 @@ const GenerationMode: UnrecallableMetadataHandler = { return Promise.resolve(parsed); }, LabelComponent: () => , - ValueComponent: ({ value }: UnrecallableMetadataLabelProps) => , + ValueComponent: ({ value }: UnrecallableMetadataValueProps) => , }; //#endregion Generation Mode //#region Positive Prompt const PositivePrompt: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'PositivePrompt', parse: (metadata, _store) => { const raw = getProperty(metadata, 'positive_prompt'); @@ -249,6 +253,7 @@ const PositivePrompt: SingleMetadataHandler = { //#region Negative Prompt const NegativePrompt: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'NegativePrompt', parse: (metadata, _store) => { const raw = getProperty(metadata, 'negative_prompt'); @@ -267,6 +272,7 @@ const NegativePrompt: SingleMetadataHandler = { //#region SDXL Positive Style Prompt const PositiveStylePrompt: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'PositiveStylePrompt', parse: (metadata, _store) => { const raw = getProperty(metadata, 'positive_style_prompt'); @@ -285,6 +291,7 @@ const PositiveStylePrompt: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'NegativeStylePrompt', parse: (metadata, _store) => { const raw = getProperty(metadata, 'negative_style_prompt'); @@ -303,6 +310,7 @@ const NegativeStylePrompt: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'CFGScale', parse: (metadata, _store) => { const raw = getProperty(metadata, 'cfg_scale'); @@ -319,6 +327,7 @@ const CFGScale: SingleMetadataHandler = { //#region CFG Rescale Multiplier const CFGRescaleMultiplier: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'CFGRescaleMultiplier', parse: (metadata, _store) => { const raw = getProperty(metadata, 'cfg_rescale_multiplier'); @@ -337,6 +346,7 @@ const CFGRescaleMultiplier: SingleMetadataHandler //#region Guidance const Guidance: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'Guidance', parse: (metadata, _store) => { const raw = getProperty(metadata, 'guidance'); @@ -353,6 +363,7 @@ const Guidance: SingleMetadataHandler = { //#region Scheduler const Scheduler: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'Scheduler', parse: (metadata, _store) => { const raw = getProperty(metadata, 'scheduler'); @@ -369,6 +380,7 @@ const Scheduler: SingleMetadataHandler = { //#region Width const Width: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'Width', parse: (metadata, _store) => { const raw = getProperty(metadata, 'width'); @@ -385,6 +397,7 @@ const Width: SingleMetadataHandler = { //#region Height const Height: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'Height', parse: (metadata, _store) => { const raw = getProperty(metadata, 'height'); @@ -401,6 +414,7 @@ const Height: SingleMetadataHandler = { //#region Seed const Seed: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'Seed', parse: (metadata, _store) => { const raw = getProperty(metadata, 'seed'); @@ -417,6 +431,7 @@ const Seed: SingleMetadataHandler = { //#region Steps const Steps: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'Steps', parse: (metadata, _store) => { const raw = getProperty(metadata, 'steps'); @@ -433,6 +448,7 @@ const Steps: SingleMetadataHandler = { //#region DenoisingStrength const DenoisingStrength: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'DenoisingStrength', parse: (metadata, _store) => { const raw = getProperty(metadata, 'strength'); @@ -449,6 +465,7 @@ const DenoisingStrength: SingleMetadataHandler = { //#region SeamlessX const SeamlessX: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'SeamlessX', parse: (metadata, _store) => { const raw = getProperty(metadata, 'seamless_x'); @@ -465,6 +482,7 @@ const SeamlessX: SingleMetadataHandler = { //#region SeamlessY const SeamlessY: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'SeamlessY', parse: (metadata, _store) => { const raw = getProperty(metadata, 'seamless_y'); @@ -481,12 +499,14 @@ const SeamlessY: SingleMetadataHandler = { //#region RefinerModel const RefinerModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerModel', parse: async (metadata, store) => { const raw = getProperty(metadata, 'refiner_model'); const parsed = await parseModelIdentifier(raw, store, 'main'); assert(parsed.type === 'main'); assert(parsed.base === 'sdxl-refiner'); + assert(isCompatibleWithMainModel(parsed, store)); return Promise.resolve(parsed); }, recall: (value, store) => { @@ -501,6 +521,7 @@ const RefinerModel: SingleMetadataHandler = { //#region RefinerSteps const RefinerSteps: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerSteps', parse: (metadata, _store) => { const raw = getProperty(metadata, 'refiner_steps'); @@ -517,6 +538,7 @@ const RefinerSteps: SingleMetadataHandler = { //#region RefinerCFGScale const RefinerCFGScale: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerCFGScale', parse: (metadata, _store) => { const raw = getProperty(metadata, 'refiner_cfg_scale'); @@ -533,6 +555,7 @@ const RefinerCFGScale: SingleMetadataHandler = { //#region RefinerScheduler const RefinerScheduler: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerScheduler', parse: (metadata, _store) => { const raw = getProperty(metadata, 'refiner_scheduler'); @@ -549,6 +572,7 @@ const RefinerScheduler: SingleMetadataHandler = { //#region RefinerPositiveAestheticScore const RefinerPositiveAestheticScore: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerPositiveAestheticScore', parse: (metadata, _store) => { const raw = getProperty(metadata, 'refiner_positive_aesthetic_score'); @@ -567,6 +591,7 @@ const RefinerPositiveAestheticScore: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerNegativeAestheticScore', parse: (metadata, _store) => { const raw = getProperty(metadata, 'refiner_negative_aesthetic_score'); @@ -585,6 +610,7 @@ const RefinerNegativeAestheticScore: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'RefinerDenoisingStart', parse: (metadata, _store) => { const raw = getProperty(metadata, 'refiner_start'); @@ -603,6 +629,7 @@ const RefinerDenoisingStart: SingleMetadataHandler = //#region MainModel const MainModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'MainModel', parse: async (metadata, store) => { const raw = getProperty(metadata, 'model'); @@ -622,11 +649,13 @@ const MainModel: SingleMetadataHandler = { //#region VAEModel const VAEModel: SingleMetadataHandler = { + [SingleMetadataKey]: true, type: 'VAEModel', parse: async (metadata, store) => { const raw = getProperty(metadata, 'vae'); const parsed = await parseModelIdentifier(raw, store, 'vae'); assert(parsed.type === 'vae'); + assert(isCompatibleWithMainModel(parsed, store)); return Promise.resolve(parsed); }, recall: (value, store) => { @@ -641,6 +670,7 @@ const VAEModel: SingleMetadataHandler = { //#region LoRAs const LoRAs: CollectionMetadataHandler = { + [CollectionMetadataKey]: true, type: 'LoRAs', parse: async (metadata, store) => { const rawArray = getProperty(metadata, 'loras'); @@ -651,6 +681,7 @@ const LoRAs: CollectionMetadataHandler = { for (const rawItem of rawArray) { try { let identifier: ModelIdentifierField | null = null; + try { // New format - { model: ModelIdenfifierField } const rawIdentifier = getProperty(rawItem, 'model'); @@ -662,8 +693,12 @@ const LoRAs: CollectionMetadataHandler = { // No need to catch here - if this throws, we move on to the next item identifier = await getModelIdentiferFromKey(key, store); } + assert(identifier.type === 'lora'); + assert(isCompatibleWithMainModel(identifier, store)); + const weight = getProperty(rawItem, 'weight'); + loras.push({ id: getPrefixedId('lora'), model: identifier, @@ -684,22 +719,20 @@ const LoRAs: CollectionMetadataHandler = { recallOne: (value, store) => { store.dispatch(loraRecalled({ lora: value })); }, - recallAll: (values, store) => { + recall: (values, store) => { store.dispatch(loraAllDeleted()); for (const lora of values) { store.dispatch(loraRecalled({ lora })); } }, - LabelComponent: ({ values, i }: CollectionMetadataLabelProps) => ( - - ), + LabelComponent: () => , ValueComponent: ({ value }: CollectionMetadataValueProps) => ( ), }; //#endregion LoRAs -export const MetadataHanders = { +export const MetadataHandlers = { CreatedBy, GenerationMode, PositivePrompt, @@ -727,10 +760,174 @@ export const MetadataHanders = { MainModel, VAEModel, LoRAs, -} satisfies Record< - string, - UnrecallableMetadataHandler | SingleMetadataHandler | CollectionMetadataHandler ->; +} as const; + +const successToast = (parameter: ReactNode) => { + toast({ + id: 'PARAMETER_SET', + title: t('toast.parameterSet'), + description: t('toast.parameterSetDesc', { parameter }), + status: 'info', + }); +}; + +const failedToast = (parameter: ReactNode, message?: ReactNode) => { + toast({ + id: 'PARAMETER_NOT_SET', + title: t('toast.parameterNotSet'), + description: message + ? t('toast.parameterNotSetDescWithMessage', { parameter, message }) + : t('toast.parameterNotSetDesc', { parameter }), + status: 'warning', + }); +}; + +const recallByHandler = async (arg: { + metadata: unknown; + handler: SingleMetadataHandler | CollectionMetadataHandler; + store: AppStore; + silent?: boolean; +}): Promise => { + const { metadata, handler, store, silent = false } = arg; + + let didRecall = false; + + try { + const value = await handler.parse(metadata, store); + handler.recall(value, store); + didRecall = true; + } catch { + // + } + + if (!silent) { + if (didRecall) { + successToast(); + } else { + failedToast(); + } + } + + return didRecall; +}; + +const recallByHandlers = async (arg: { + metadata: unknown; + handlers: (SingleMetadataHandler | CollectionMetadataHandler)[]; + store: AppStore; + silent?: boolean; +}): Promise | CollectionMetadataHandler, unknown>> => { + const { metadata, handlers, store, silent = false } = arg; + + const recalled = new Map | CollectionMetadataHandler, unknown>(); + + // It's possible for some metadata item's recall to clobber the recall of another. For example, the model recall + // may change the width and height. If we are also recalling the width and height directly, we need to ensure that the + // model is recalled first, so it doesn't accidentally override the width and height. This is the only known case + // where the order of recall matters. + const sortedHandlers = handlers.sort((a, b) => { + if (a === MetadataHandlers.MainModel) { + return -1; // MainModel should be recalled first + } else if (b === MetadataHandlers.MainModel) { + return 1; // MainModel should be recalled first + } else { + return 0; // Keep the original order for other handlers + } + }); + + for (const handler of sortedHandlers) { + try { + const value = await handler.parse(metadata, store); + handler.recall(value, store); + recalled.set(handler, value); + } catch (error) { + // + } + } + + // If we recalled style prompts, and they were _different_ from the positive prompt, we need to disable prompt concat. + const positivePrompt = recalled.get(MetadataHandlers.PositivePrompt); + const negativePrompt = recalled.get(MetadataHandlers.NegativePrompt); + const positiveStylePrompt = recalled.get(MetadataHandlers.PositiveStylePrompt); + const negativeStylePrompt = recalled.get(MetadataHandlers.NegativeStylePrompt); + + if ( + (positiveStylePrompt && positiveStylePrompt !== positivePrompt) || + (negativeStylePrompt && negativeStylePrompt !== negativePrompt) + ) { + // If we set the negative style prompt or positive style prompt, we should disable prompt concat + store.dispatch(shouldConcatPromptsChanged(false)); + } else { + // Otherwise, we should enable prompt concat + store.dispatch(shouldConcatPromptsChanged(true)); + } + + if (!silent) { + if (recalled.size > 0) { + toast({ + id: 'PARAMETER_SET', + title: t('toast.parametersSet'), + status: 'info', + }); + } else { + toast({ + id: 'PARAMETER_SET', + title: t('toast.parametersNotSet'), + status: 'warning', + }); + } + } + + return recalled; +}; + +const recallPrompts = async (metadata: unknown, store: AppStore) => { + const recalled = await recallByHandlers({ + metadata, + handlers: [ + MetadataHandlers.PositivePrompt, + MetadataHandlers.NegativePrompt, + MetadataHandlers.PositiveStylePrompt, + MetadataHandlers.NegativeStylePrompt, + ], + store, + silent: true, + }); + if (recalled.size > 0) { + successToast(t('metadata.allPrompts')); + } +}; + +const recallDimensions = async (metadata: unknown, store: AppStore) => { + const recalled = await recallByHandlers({ + metadata, + handlers: [MetadataHandlers.Width, MetadataHandlers.Height], + store, + silent: true, + }); + if (recalled.size > 0) { + successToast(t('metadata.imageDimensions')); + } +}; + +const recallAll = async (metadata: unknown, store: AppStore) => { + const handlers = Object.values(MetadataHandlers).filter( + (handler) => isSingleMetadataHandler(handler) || isCollectionMetadataHandler(handler) + ); + await recallByHandlers({ + metadata, + handlers, + store, + }); +}; + +export const MetadataUtils = { + recallByHandler, + recallByHandlers, + recallAll, + recallPrompts, + recallDimensions, +} as const; export function useSingleMetadataDatum(metadata: unknown, handler: SingleMetadataHandler) { const store = useAppStore(); @@ -791,7 +988,7 @@ export function useCollectionMetadataDatum(metadata: unknown, h const recallAll = useCallback( (values: T) => { - handler.recallAll(values, store); + handler.recall(values, store); }, [handler, store] ); @@ -829,34 +1026,38 @@ export function useUnrecallableMetadataDatum(metadata: unknown, handler: Unre return { data }; } -const OPTIONS = { subscribe: false }; +const options = { subscribe: false }; const getModelIdentiferFromKey = async (key: string, store: AppStore): Promise => { - const req = store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, OPTIONS)); + const req = store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, options)); const modelConfig = await req.unwrap(); return modelConfig; }; const parseModelIdentifier = async (raw: unknown, store: AppStore, type: ModelType): Promise => { - // First try the current format identifier: key, name, base, type, hash try { + // First try the current format identifier: key, name, base, type, hash const { key } = zModelIdentifierField.parse(raw); - const req = store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, OPTIONS)); + const req = store.dispatch(modelsApi.endpoints.getModelConfig.initiate(key, options)); const modelConfig = await req.unwrap(); return zModelIdentifierField.parse(modelConfig); } catch { - // noop + // We'll try to parse the old format identifier next } // Fall back to old format identifier: model_name, base_model - try { - const { model_name: name, base_model: base } = zModelIdentifier.parse(raw); - const arg = { name, base, type }; - const req = store.dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate(arg, OPTIONS)); - const modelConfig = await req.unwrap(); - return zModelIdentifierField.parse(modelConfig); - } catch { - // noop - } - throw new Error('Unable to parse model identifier'); + // No error handling here - this is our last chance to get a model identifier + const { model_name, base_model } = zModelIdentifier.parse(raw); + const arg = { name: model_name, base: base_model, type }; + const req = store.dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate(arg, options)); + const modelConfig = await req.unwrap(); + return zModelIdentifierField.parse(modelConfig); +}; + +const isCompatibleWithMainModel = (candidate: ModelIdentifierField, store: AppStore) => { + const base = selectBase(store.getState()); + if (!base) { + return true; + } + return candidate.base === base; };