From 64523c4b1b4add4d705cb98abe1b282af35c8bea Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:34:35 +1000 Subject: [PATCH] fix(ui): handle concat when recalling prompts This required some minor reworking of of the logic to recall multiple items. I split this into a utility function that includes some special handling for concat. Closes #6478 --- .../src/features/metadata/util/handlers.ts | 123 +++++++++--------- 1 file changed, 65 insertions(+), 58 deletions(-) diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index 2829507dcd..33715cbbe1 100644 --- a/invokeai/frontend/web/src/features/metadata/util/handlers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts @@ -1,4 +1,7 @@ +import { getStore } from 'app/store/nanostores/store'; +import { deepClone } from 'common/util/deepClone'; import { objectKeys } from 'common/util/objectKeys'; +import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice'; import type { Layer } from 'features/controlLayers/store/types'; import type { LoRA } from 'features/lora/store/loraSlice'; import type { @@ -16,6 +19,7 @@ import { validators } from 'features/metadata/util/validators'; import type { ModelIdentifierField } from 'features/nodes/types/common'; import { toast } from 'features/toast/toast'; import { t } from 'i18next'; +import { size } from 'lodash-es'; import { assert } from 'tsafe'; import { parsers } from './parsers'; @@ -376,54 +380,25 @@ export const handlers = { }), } as const; +type ParsedValue = Awaited>; +type RecallResults = Partial>; + export const parseAndRecallPrompts = async (metadata: unknown) => { - const results = await Promise.allSettled([ - handlers.positivePrompt.parse(metadata).then((positivePrompt) => { - if (!handlers.positivePrompt.recall) { - return; - } - handlers.positivePrompt?.recall(positivePrompt); - }), - handlers.negativePrompt.parse(metadata).then((negativePrompt) => { - if (!handlers.negativePrompt.recall) { - return; - } - handlers.negativePrompt?.recall(negativePrompt); - }), - handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => { - if (!handlers.sdxlPositiveStylePrompt.recall) { - return; - } - handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt); - }), - handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => { - if (!handlers.sdxlNegativeStylePrompt.recall) { - return; - } - handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt); - }), - ]); - if (results.some((result) => result.status === 'fulfilled')) { + const keysToRecall: (keyof typeof handlers)[] = [ + 'positivePrompt', + 'negativePrompt', + 'sdxlPositiveStylePrompt', + 'sdxlNegativeStylePrompt', + ]; + const recalled = await recallKeys(keysToRecall, metadata); + if (size(recalled) > 0) { parameterSetToast(t('metadata.allPrompts')); } }; export const parseAndRecallImageDimensions = async (metadata: unknown) => { - const results = await Promise.allSettled([ - handlers.width.parse(metadata).then((width) => { - if (!handlers.width.recall) { - return; - } - handlers.width?.recall(width); - }), - handlers.height.parse(metadata).then((height) => { - if (!handlers.height.recall) { - return; - } - handlers.height?.recall(height); - }), - ]); - if (results.some((result) => result.status === 'fulfilled')) { + const recalled = recallKeys(['width', 'height'], metadata); + if (size(recalled) > 0) { parameterSetToast(t('metadata.imageDimensions')); } }; @@ -438,28 +413,20 @@ export const parseAndRecallAllMetadata = async ( toControlLayers: boolean, skip: (keyof typeof handlers)[] = [] ) => { - const skipKeys = skip ?? []; + const skipKeys = deepClone(skip); if (toControlLayers) { skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS); } else { skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS); } - const results = await Promise.allSettled( - objectKeys(handlers) - .filter((key) => !skipKeys.includes(key)) - .map((key) => { - const { parse, recall } = handlers[key]; - return parse(metadata).then((value) => { - if (!recall) { - return; - } - /* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */ - recall(value); - }); - }) - ); - if (results.some((result) => result.status === 'fulfilled')) { + // We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt + // concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all + // results + const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key)); + const recalled = await recallKeys(keysToRecall, metadata); + + if (size(recalled) > 0) { toast({ id: 'PARAMETER_SET', title: t('toast.parametersSet'), @@ -473,3 +440,43 @@ export const parseAndRecallAllMetadata = async ( }); } }; + +/** + * Recalls a set of keys from metadata. + * Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style" + * prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation. + * @param keysToRecall An array of keys to recall. + * @param metadata The metadata to recall from + * @returns A promise that resolves to an object containing the recalled values. + */ +const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise => { + const { dispatch } = getStore(); + const recalled: RecallResults = {}; + for (const key of keysToRecall) { + const { parse, recall } = handlers[key]; + if (!recall) { + continue; + } + try { + const value = await parse(metadata); + /* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */ + await recall(value); + recalled[key] = value; + } catch { + // no-op + } + } + + if ( + (recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) || + (recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt']) + ) { + // If we set the negative style prompt or positive style prompt, we should disable prompt concat + dispatch(shouldConcatPromptsChanged(false)); + } else { + // Otherwise, we should enable prompt concat + dispatch(shouldConcatPromptsChanged(true)); + } + + return recalled; +};