From c550ce31a9487e7c6d1ed18efa24632ac2804693 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 20 Apr 2026 16:58:07 -0400 Subject: [PATCH 1/2] fix(ui): FLUX.2 Klein VAE/Qwen3 readiness checks and diffusers source auto-detection (#9041) * fix(ui): FLUX.2 Klein VAE/Qwen3 readiness checks and diffusers source auto-detection Fix several issues with FLUX.2 Klein model handling: 1. Add readiness validation for non-diffusers Klein models so the invoke button is disabled when required VAE/Qwen3 submodels are missing. 2. Auto-detect installed diffusers flux2 models and pass them as qwen3_source_model in the graph builder, so GGUF/safetensors models can extract VAE and encoder from an available diffusers model. 3. Use variant-aware matching so Klein 9B models pick a 9B diffusers source (not 4B), preventing Qwen3 encoder dimension mismatches. 4. Change placeholder text from "From main model" to "From diffusers model" or "No diffusers model available" depending on availability. 5. Export readiness check functions and add comprehensive tests for both the graph builder and readiness logic. Co-Authored-By: Claude Opus 4.6 (1M context) * Chore Fix merge * fix(ui): unify FLUX.2 Klein Qwen3 variant matching Extract KLEIN_TO_QWEN3_VARIANT_MAP and isFlux2KleinQwen3Compatible into features/parameters/util/flux2Klein so UI placeholder, readiness check, and graph builder share one rule. Accepts klein_9b and klein_9b_base as mutual Qwen3 sources (both use qwen3_8b) and guards against undefined === undefined false positives. Use zModelIdentifierField.parse for qwen3_source_model construction in buildFLUXGraph, matching the pattern used for Z-Image. --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Alexander Eichhorn --- invokeai/frontend/web/public/locales/en.json | 8 +- .../graph/generation/buildFLUXGraph.test.ts | 370 ++++++++++++++++++ .../util/graph/generation/buildFLUXGraph.ts | 24 +- .../Advanced/ParamFlux2KleinModelSelect.tsx | 38 +- .../features/parameters/util/flux2Klein.ts | 24 ++ .../features/queue/store/readiness.test.ts | 272 +++++++++++++ .../web/src/features/queue/store/readiness.ts | 63 ++- .../src/services/api/hooks/modelsByType.ts | 3 + .../frontend/web/src/services/api/types.ts | 4 + 9 files changed, 785 insertions(+), 21 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts create mode 100644 invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts create mode 100644 invokeai/frontend/web/src/features/queue/store/readiness.test.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1d99f2cae4..3e7e742934 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1316,9 +1316,11 @@ "zImageQwen3Source": "Qwen3 & VAE Source Model", "zImageQwen3SourcePlaceholder": "Required if VAE/Encoder empty", "flux2KleinVae": "VAE (optional)", - "flux2KleinVaePlaceholder": "From main model", + "flux2KleinVaePlaceholder": "From diffusers model", + "flux2KleinVaeNoModelPlaceholder": "No diffusers model available", "flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)", - "flux2KleinQwen3EncoderPlaceholder": "From main model", + "flux2KleinQwen3EncoderPlaceholder": "From diffusers model", + "flux2KleinQwen3EncoderNoModelPlaceholder": "No diffusers model available", "qwenImageComponentSource": "VAE/Encoder Source (Diffusers)", "qwenImageComponentSourcePlaceholder": "Required for GGUF models", "qwenImageQuantization": "Encoder Quantization", @@ -1623,6 +1625,8 @@ "noFLUXVAEModelSelected": "No VAE model selected for FLUX generation", "noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation", "noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation", + "noFlux2KleinVaeModelSelected": "No VAE selected. Non-diffusers FLUX.2 Klein models require a standalone VAE", + "noFlux2KleinQwen3EncoderModelSelected": "No Qwen3 Encoder selected. Non-diffusers FLUX.2 Klein models require a standalone Qwen3 Encoder", "noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder", "noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model", "noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model", diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts new file mode 100644 index 0000000000..7f01becc3d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -0,0 +1,370 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('app/logging/logger', () => ({ + logger: () => ({ + debug: vi.fn(), + }), +})); + +let nextId = 0; +vi.mock('features/controlLayers/konva/util', () => ({ + getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`, +})); + +// --- Flux2 Klein model fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-klein-diffusers', + hash: 'flux2-diff-hash', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const flux2GGUFModel = { + key: 'flux2-klein-gguf', + hash: 'flux2-gguf-hash', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +}; + +const kleinVaeModelFixture = { key: 'klein-vae', name: 'Klein VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3EncoderModelFixture = { + key: 'klein-qwen3', + name: 'Qwen3 4B', + base: 'flux2', + type: 'qwen3_encoder', +}; + +const flux2GGUF9BModel = { + key: 'flux2-klein-gguf-9b', + hash: 'flux2-gguf-9b-hash', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +}; + +const diffusersSourceModelFixture = { + key: 'flux2-source-diffusers', + hash: 'flux2-src-hash', + name: 'FLUX.2 Klein 4B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const diffusers9BSourceModelFixture = { + key: 'flux2-source-diffusers-9b', + hash: 'flux2-src-9b-hash', + name: 'FLUX.2 Klein 9B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_9b', +}; + +// --- Mutable state --- + +let model: Record = { ...flux2DiffusersModel }; +let kleinVaeModel: Record | null = null; +let kleinQwen3EncoderModel: Record | null = null; +let diffusersModels: Record[] = []; + +vi.mock('features/controlLayers/store/paramsSlice', () => ({ + selectMainModelConfig: vi.fn(() => model), + selectParamsSlice: vi.fn(() => ({ + guidance: 4, + steps: 20, + fluxScheduler: 'euler', + fluxDypePreset: 'off', + fluxDypeScale: 2.0, + fluxDypeExponent: 2.0, + fluxVAE: null, + t5EncoderModel: null, + clipEmbedModel: null, + })), + selectKleinVaeModel: vi.fn(() => kleinVaeModel), + selectKleinQwen3EncoderModel: vi.fn(() => kleinQwen3EncoderModel), +})); + +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ + selectRefImagesSlice: vi.fn(() => ({ + entities: [], + })), +})); + +vi.mock('features/controlLayers/store/selectors', () => ({ + selectCanvasMetadata: vi.fn(() => ({})), + selectCanvasSlice: vi.fn(() => ({})), +})); + +vi.mock('features/controlLayers/store/types', () => ({ + isFlux2ReferenceImageConfig: vi.fn(() => false), + isFluxKontextReferenceImageConfig: vi.fn(() => false), +})); + +vi.mock('features/controlLayers/store/validators', () => ({ + getGlobalReferenceImageWarnings: vi.fn(() => []), +})); + +vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({ + addFlux2KleinLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({ + addFLUXFill: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({ + addFLUXLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({ + addFLUXReduxes: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ + addImageToImage: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ + addInpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ + addNSFWChecker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ + addOutpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ + addRegions: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ + addTextToImage: vi.fn(({ l2i }) => l2i), +})); + +vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ + addWatermarker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({ + addControlLoRA: vi.fn(), + addControlNets: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({ + addIPAdapters: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ + selectCanvasOutputFields: vi.fn(() => ({})), + selectPresetModifiedPrompts: vi.fn(() => ({ + positive: 'a prompt', + negative: '', + })), +})); + +vi.mock('features/ui/store/uiSelectors', () => ({ + selectActiveTab: vi.fn(() => 'generation'), +})); + +vi.mock('services/api/hooks/modelsByType', () => ({ + selectFlux2DiffusersModels: vi.fn(() => diffusersModels), +})); + +vi.mock('services/api/types', async () => { + const actual = await vi.importActual('services/api/types'); + return { + ...actual, + isNonRefinerMainModelConfig: vi.fn(() => true), + }; +}); + +import { buildFLUXGraph } from './buildFLUXGraph'; + +const buildGraphArg = () => ({ + generationMode: 'txt2img' as const, + manager: null, + state: { + system: { + shouldUseNSFWChecker: false, + shouldUseWatermarker: false, + }, + } as never, +}); + +/** Find the flux2_klein_model_loader node in the graph. */ +const getLoaderNode = async () => { + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const loaderEntry = Object.entries(graph.nodes).find(([id]) => id.startsWith('flux2_klein_model_loader:')); + return loaderEntry?.[1] as Record | undefined; +}; + +describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { + afterEach(() => { + nextId = 0; + model = { ...flux2DiffusersModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = null; + diffusersModels = []; + }); + + it('does not set qwen3_source_model when main model is diffusers', async () => { + model = { ...flux2DiffusersModel }; + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when main model is GGUF and a diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toEqual({ + key: diffusersSourceModelFixture.key, + hash: diffusersSourceModelFixture.hash, + name: diffusersSourceModelFixture.name, + base: diffusersSourceModelFixture.base, + type: diffusersSourceModelFixture.type, + }); + }); + + it('does not set qwen3_source_model when main model is GGUF but standalone VAE and Qwen3 are both selected', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('does not set qwen3_source_model when main model is GGUF and no diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = []; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when only VAE is selected but Qwen3 is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = null; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('sets qwen3_source_model when only Qwen3 is selected but VAE is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('passes standalone vae_model and qwen3_encoder_model when selected', async () => { + model = { ...flux2DiffusersModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.vae_model).toEqual(kleinVaeModelFixture); + expect(loader!.qwen3_encoder_model).toEqual(kleinQwen3EncoderModelFixture); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + describe('variant matching', () => { + it('selects a variant-matching diffusers model when multiple are available', async () => { + model = { ...flux2GGUF9BModel }; + diffusersModels = [diffusersSourceModelFixture, diffusers9BSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should pick the 9B diffusers model, not the 4B + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusers9BSourceModelFixture.key })); + }); + + it('falls back to any diffusers model for VAE when standalone Qwen3 is selected but no variant match', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + // Only 4B diffusers available, no 9B — but Qwen3 is already provided standalone + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should use the 4B diffusers model just for VAE extraction + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusersSourceModelFixture.key })); + }); + + it('does not set qwen3_source_model when GGUF 9B with only 4B diffusers available and no standalone Qwen3', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = null; + // Only 4B diffusers available — wrong variant for Qwen3, no standalone Qwen3 selected + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should NOT use the 4B diffusers since it has the wrong Qwen3 encoder + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + }); + + describe('graph structure', () => { + it('uses flux2_klein_model_loader for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_model_loader:'))).toBe(true); + }); + + it('uses flux2_vae_decode for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_vae_decode:'))).toBe(true); + }); + + it('uses flux2_klein_text_encoder for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_text_encoder:'))).toBe(true); + }); + + it('uses flux2_denoise for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeTypes = Object.values(graph.nodes).map((n) => n.type); + expect(nodeTypes).toContain('flux2_denoise'); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index ba27e5dbf6..407c921421 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -10,7 +10,8 @@ import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlic import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isFlux2ReferenceImageConfig, isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types'; import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; -import { zImageField } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { zImageField, zModelIdentifierField } from 'features/nodes/types/common'; import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux2KleinLoRAs'; import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill'; import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; @@ -26,8 +27,10 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { t } from 'i18next'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -141,7 +144,23 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); + const sourceModel = variantMatch ?? (kleinQwen3EncoderModel ? diffusersModels[0] : undefined); + if (sourceModel) { + qwen3SourceModel = zModelIdentifierField.parse(sourceModel); + } + } + modelLoader = g.addNode({ type: 'flux2_klein_model_loader', id: getPrefixedId('flux2_klein_model_loader'), @@ -149,6 +168,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise { const dispatch = useAppDispatch(); const { t } = useTranslation(); const kleinVaeModel = useAppSelector(selectKleinVaeModel); + const mainModelConfig = useAppSelector(selectMainModelConfig); const [modelConfigs, { isLoading }] = useFlux2VAEModels(); + const [diffusersModels] = useFlux2DiffusersModels(); const _onChange = useCallback( (model: VAEModelConfig | null) => { @@ -42,6 +45,11 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { isLoading, }); + const hasDiffusersSource = mainModelConfig?.format === 'diffusers' || diffusersModels.length > 0; + const placeholder = hasDiffusersSource + ? t('modelManager.flux2KleinVaePlaceholder') + : t('modelManager.flux2KleinVaeNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinVae')} @@ -51,7 +59,7 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinVaePlaceholder')} + placeholder={placeholder} /> ); @@ -59,15 +67,6 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { ParamFlux2KleinVaeModelSelect.displayName = 'ParamFlux2KleinVaeModelSelect'; -/** - * Maps FLUX.2 Klein variants to compatible Qwen3 encoder variants - */ -const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { - klein_4b: 'qwen3_4b', - klein_9b: 'qwen3_8b', - klein_9b_base: 'qwen3_8b', -}; - /** * FLUX.2 Klein Qwen3 Encoder Model Select * Selects a Qwen3 text encoder model for FLUX.2 Klein @@ -79,6 +78,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { const kleinQwen3EncoderModel = useAppSelector(selectKleinQwen3EncoderModel); const mainModelConfig = useAppSelector(selectMainModelConfig); const [allModelConfigs, { isLoading }] = useQwen3EncoderModels(); + const [diffusersModels] = useFlux2DiffusersModels(); // Filter Qwen3 encoders based on the main model's variant const modelConfigs = useMemo(() => { @@ -112,6 +112,20 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { isLoading, }); + // Qwen3 encoder requires a Qwen3-compatible diffusers model (variants that share the same Qwen3 encoder). + const hasMatchingDiffusersSource = + mainModelConfig?.format === 'diffusers' || + diffusersModels.some( + (m) => + 'variant' in m && + mainModelConfig && + 'variant' in mainModelConfig && + isFlux2KleinQwen3Compatible(m.variant, mainModelConfig.variant) + ); + const placeholder = hasMatchingDiffusersSource + ? t('modelManager.flux2KleinQwen3EncoderPlaceholder') + : t('modelManager.flux2KleinQwen3EncoderNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinQwen3Encoder')} @@ -121,7 +135,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinQwen3EncoderPlaceholder')} + placeholder={placeholder} /> ); diff --git a/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts new file mode 100644 index 0000000000..b9508a4f82 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts @@ -0,0 +1,24 @@ +/** + * Maps a FLUX.2 Klein main-model variant to the Qwen3 encoder variant it uses. + * Multiple Klein variants can share the same Qwen3 variant (e.g. `klein_9b` and + * `klein_9b_base` both use `qwen3_8b`), so two different Klein variants can be + * Qwen3-compatible sources for each other. + */ +export const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { + klein_4b: 'qwen3_4b', + klein_9b: 'qwen3_8b', + klein_9b_base: 'qwen3_8b', +}; + +/** + * Returns true if two Klein variants share the same Qwen3 encoder and can therefore + * be used as a Qwen3 source for each other. + */ +export const isFlux2KleinQwen3Compatible = (variantA: unknown, variantB: unknown): boolean => { + if (typeof variantA !== 'string' || typeof variantB !== 'string') { + return false; + } + const qwen3A = KLEIN_TO_QWEN3_VARIANT_MAP[variantA]; + const qwen3B = KLEIN_TO_QWEN3_VARIANT_MAP[variantB]; + return qwen3A !== undefined && qwen3A === qwen3B; +}; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts new file mode 100644 index 0000000000..632006050e --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -0,0 +1,272 @@ +import { describe, expect, it, vi } from 'vitest'; + +vi.mock('features/dynamicPrompts/util/getShouldProcessPrompt', () => ({ + getShouldProcessPrompt: vi.fn(() => false), +})); + +vi.mock('i18next', () => ({ + default: { + t: (key: string) => key, + }, +})); + +import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types'; +import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; +import type { MainModelConfig } from 'services/api/types'; + +import { getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness'; + +// --- Fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-diff', + hash: 'h', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF4BModel = { + key: 'flux2-gguf-4b', + hash: 'h', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF9BModel = { + key: 'flux2-gguf-9b', + hash: 'h', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +} as unknown as MainModelConfig; + +const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; + +const baseDynamicPrompts: DynamicPromptsState = { + _version: 1, + maxPrompts: 100, + combinatorial: false, + prompts: ['test prompt'], + parsingError: undefined, + isError: false, + isLoading: false, + seedBehaviour: 'PER_PROMPT', +}; + +const baseRefImages: RefImagesState = { + entities: [], + ipAdapters: { entities: [], ids: [] }, +} as unknown as RefImagesState; + +const baseParams = { + positivePrompt: 'test', + kleinVaeModel: null, + kleinQwen3EncoderModel: null, +} as unknown as ParamsState; + +// --- Helpers --- + +const buildGenerateTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const buildCanvasTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + canvas: { + bbox: { + scaleMethod: 'none', + rect: { width: 1024, height: 1024 }, + scaledSize: { width: 1024, height: 1024 }, + }, + controlLayers: { entities: [] }, + regionalGuidance: { entities: [] }, + rasterLayers: { entities: [] }, + inpaintMasks: { entities: [] }, + }, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + canvasIsFiltering: false, + canvasIsTransforming: false, + canvasIsRasterizing: false, + canvasIsCompositing: false, + canvasIsSelectingObject: false, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const hasFlux2VaeReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinVaeModelSelected')); + +const hasFlux2Qwen3Reason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected')); + +// --- Tests --- + +describe('FLUX.2 Klein readiness checks – generate tab', () => { + it('no errors when main model is diffusers (VAE/Qwen3 extracted from it)', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2DiffusersModel })); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no diffusers source and no standalone models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2GGUF4BModel })); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for Qwen3 when GGUF model with standalone VAE but no Qwen3 and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinVaeModel: kleinVaeModel }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for VAE when GGUF model with standalone Qwen3 but no VAE and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinQwen3EncoderModel: kleinQwen3Model }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B model with only a 4B diffusers source (variant mismatch)', () => { + // User has Klein 9B GGUF selected, only a 4B diffusers model installed. + // VAE is shared across variants so it's ok. Qwen3 encoder differs, so it's not ok. + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF 9B model with variant-matching diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); +}); + +describe('FLUX.2 Klein readiness checks – canvas tab', () => { + it('no errors when main model is diffusers', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2GGUF4BModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B with variant-mismatched diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 67dfe3141c..5802a2aed5 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -33,12 +33,14 @@ import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/ import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue'; import type { UpscaleState } from 'features/parameters/store/upscaleSlice'; import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { getGridSize } from 'features/parameters/util/optimalDimension'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import type { TabName } from 'features/ui/store/uiTypes'; import i18n from 'i18next'; import { atom, computed } from 'nanostores'; import { useEffect } from 'react'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { MainOrExternalModelConfig } from 'services/api/types'; import { isExternalApiModelConfig } from 'services/api/types'; import { $isConnected } from 'services/events/stores'; @@ -109,6 +111,12 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { } = arg; if (tab === 'generate') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueGenerateTab({ isConnected, model, @@ -116,10 +124,18 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { refImages, dynamicPrompts, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'canvas') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueCanvasTab({ isConnected, model, @@ -133,6 +149,8 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { canvasIsCompositing, canvasIsSelectingObject, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'workflows') { @@ -220,15 +238,26 @@ export const useReadinessWatcher = () => { const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') }); -const getReasonsWhyCannotEnqueueGenerateTab = (arg: { +export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; params: ParamsState; refImages: RefImagesState; loras: LoRA[]; dynamicPrompts: DynamicPromptsState; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { - const { isConnected, model, params, refImages, loras, dynamicPrompts } = arg; + const { + isConnected, + model, + params, + refImages, + loras, + dynamicPrompts, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, + } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -260,7 +289,17 @@ const getReasonsWhyCannotEnqueueGenerateTab = (arg: { } } - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + if (model?.base === 'flux2' && model.format !== 'diffusers') { + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } if (model?.base === 'qwen-image' && model.format === 'gguf_quantized') { if (!params.qwenImageComponentSource) { @@ -452,7 +491,7 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: { return reasons; }; -const getReasonsWhyCannotEnqueueCanvasTab = (arg: { +export const getReasonsWhyCannotEnqueueCanvasTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; canvas: CanvasState; @@ -465,6 +504,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing: boolean; canvasIsCompositing: boolean; canvasIsSelectingObject: boolean; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { const { isConnected, @@ -479,6 +520,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing, canvasIsCompositing, canvasIsSelectingObject, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -571,7 +614,17 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { } if (model?.base === 'flux2') { - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (model.format !== 'diffusers') { + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } const { bbox } = canvas; const gridSize = getGridSize('flux'); // FLUX.2 uses same grid size as FLUX.1 diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 55746e5294..2496c06ed0 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -18,6 +18,7 @@ import { isControlNetModelConfig, isExternalApiModelConfig, isFlux1VAEModelConfig, + isFlux2DiffusersMainModelConfig, isFlux2VAEModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig, @@ -101,6 +102,7 @@ export const useFlux2VAEModels = () => buildModelsHook(isFlux2VAEModelConfig)(); export const useAnimaVAEModels = () => buildModelsHook(isAnimaVAEModelConfig)(); export const useAnimaQwen3EncoderModels = () => buildModelsHook(isAnimaQwen3EncoderModelConfig)(); export const useZImageDiffusersModels = () => buildModelsHook(isZImageDiffusersMainModelConfig)(); +export const useFlux2DiffusersModels = () => buildModelsHook(isFlux2DiffusersMainModelConfig)(); export const useQwenImageDiffusersModels = () => buildModelsHook(isQwenImageDiffusersMainModelConfig)(); export const useQwen3EncoderModels = () => buildModelsHook(isQwen3EncoderModelConfig)(); export const useGlobalReferenceImageModels = buildModelsHook( @@ -140,6 +142,7 @@ export const selectAnimaQwen3EncoderModels = buildModelsSelector(isAnimaQwen3Enc export const selectQwen3EncoderModels = buildModelsSelector(isQwen3EncoderModelConfig); export const selectQwenImageDiffusersModels = buildModelsSelector(isQwenImageDiffusersMainModelConfig); export const selectZImageDiffusersModels = buildModelsSelector(isZImageDiffusersMainModelConfig); +export const selectFlux2DiffusersModels = buildModelsSelector(isFlux2DiffusersMainModelConfig); export const selectFluxVAEModels = buildModelsSelector(isFluxVAEModelConfig); export const selectAnimaVAEModels = buildModelsSelector(isAnimaVAEModelConfig); export const selectT5EncoderModels = buildModelsSelector(isT5EncoderModelConfigOrSubmodel); diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 3624d7ef6a..9deefada23 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -457,6 +457,10 @@ export const isZImageDiffusersMainModelConfig = (config: AnyModelConfig): config return config.type === 'main' && config.base === 'z-image' && config.format === 'diffusers'; }; +export const isFlux2DiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base === 'flux2' && config.format === 'diffusers'; +}; + export const isQwenImageDiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'qwen-image' && config.format === 'diffusers'; }; From d7d623e1d5c17dcc64e5fe82205dd2b513099ba9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ufuk=20Sarp=20Sel=C3=A7ok?= Date: Tue, 21 Apr 2026 00:18:25 +0300 Subject: [PATCH 2/2] Fix: Preserve reference image panel state and selection on recall (#9010) * Fix: Preserve reference image panel state and selection on recall Made-with: Cursor * chore: apply Prettier to refImagesSlice Made-with: Cursor * fix: refine ref image recall selection behavior Simplify image-name fingerprint fallback, add an explicit guard for open-panel with null selection, and document the acceptable empty-config collision tradeoff. Made-with: Cursor --------- Co-authored-by: Alexander Eichhorn --- .../controlLayers/store/refImagesSlice.ts | 47 +++++++++++++++++-- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index 2b7c0f7d17..1ea7626290 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -35,6 +35,15 @@ type PayloadActionWithId = T extends void } & T >; +/** Fingerprint used to match the same reference image entry after recall when ids are regenerated. */ +/** Empty configs of the same type may collide; the worst case is selecting an equivalent empty entity. */ +const getRefImageRecallMatchKey = (entity: RefImageState): string => { + const { config } = entity; + const imageName = config.image?.original.image.image_name ?? ''; + const modelKey = 'model' in config && config.model ? config.model.key : ''; + return `${config.type}\0${modelKey}\0${imageName}`; +}; + const slice = createSlice({ name: 'refImages', initialState: getInitialRefImagesState(), @@ -54,13 +63,41 @@ const slice = createSlice({ }, refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => { const { entities, replace } = action.payload; - if (replace) { - state.entities = entities; - state.isPanelOpen = false; - state.selectedEntityId = null; - } else { + if (!replace) { state.entities.push(...entities); + return; } + const wasPanelOpen = state.isPanelOpen; + const previousSelectedId = state.selectedEntityId; + let previousEntity: RefImageState | null = null; + if (previousSelectedId !== null) { + previousEntity = state.entities.find((e) => e.id === previousSelectedId) ?? null; + } + state.entities = entities; + if (entities.length === 0) { + state.selectedEntityId = null; + state.isPanelOpen = false; + return; + } + if (!wasPanelOpen) { + state.selectedEntityId = null; + return; + } + const firstEntity = entities[0]; + assert(firstEntity); + if (previousSelectedId === null) { + // Open panel must have a selection; otherwise, fall back to the first entity. + state.selectedEntityId = firstEntity.id; + return; + } + if (previousSelectedId !== null && entities.some((e) => e.id === previousSelectedId)) { + state.selectedEntityId = previousSelectedId; + return; + } + const previousKey = previousEntity ? getRefImageRecallMatchKey(previousEntity) : null; + const matched = + previousKey !== null ? entities.find((e) => getRefImageRecallMatchKey(e) === previousKey) : undefined; + state.selectedEntityId = matched?.id ?? firstEntity.id; }, refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => { const { id, croppableImage } = action.payload;