From c550ce31a9487e7c6d1ed18efa24632ac2804693 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 20 Apr 2026 16:58:07 -0400 Subject: [PATCH 1/7] fix(ui): FLUX.2 Klein VAE/Qwen3 readiness checks and diffusers source auto-detection (#9041) * fix(ui): FLUX.2 Klein VAE/Qwen3 readiness checks and diffusers source auto-detection Fix several issues with FLUX.2 Klein model handling: 1. Add readiness validation for non-diffusers Klein models so the invoke button is disabled when required VAE/Qwen3 submodels are missing. 2. Auto-detect installed diffusers flux2 models and pass them as qwen3_source_model in the graph builder, so GGUF/safetensors models can extract VAE and encoder from an available diffusers model. 3. Use variant-aware matching so Klein 9B models pick a 9B diffusers source (not 4B), preventing Qwen3 encoder dimension mismatches. 4. Change placeholder text from "From main model" to "From diffusers model" or "No diffusers model available" depending on availability. 5. Export readiness check functions and add comprehensive tests for both the graph builder and readiness logic. Co-Authored-By: Claude Opus 4.6 (1M context) * Chore Fix merge * fix(ui): unify FLUX.2 Klein Qwen3 variant matching Extract KLEIN_TO_QWEN3_VARIANT_MAP and isFlux2KleinQwen3Compatible into features/parameters/util/flux2Klein so UI placeholder, readiness check, and graph builder share one rule. Accepts klein_9b and klein_9b_base as mutual Qwen3 sources (both use qwen3_8b) and guards against undefined === undefined false positives. Use zModelIdentifierField.parse for qwen3_source_model construction in buildFLUXGraph, matching the pattern used for Z-Image. --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Alexander Eichhorn --- invokeai/frontend/web/public/locales/en.json | 8 +- .../graph/generation/buildFLUXGraph.test.ts | 370 ++++++++++++++++++ .../util/graph/generation/buildFLUXGraph.ts | 24 +- .../Advanced/ParamFlux2KleinModelSelect.tsx | 38 +- .../features/parameters/util/flux2Klein.ts | 24 ++ .../features/queue/store/readiness.test.ts | 272 +++++++++++++ .../web/src/features/queue/store/readiness.ts | 63 ++- .../src/services/api/hooks/modelsByType.ts | 3 + .../frontend/web/src/services/api/types.ts | 4 + 9 files changed, 785 insertions(+), 21 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts create mode 100644 invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts create mode 100644 invokeai/frontend/web/src/features/queue/store/readiness.test.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1d99f2cae4..3e7e742934 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1316,9 +1316,11 @@ "zImageQwen3Source": "Qwen3 & VAE Source Model", "zImageQwen3SourcePlaceholder": "Required if VAE/Encoder empty", "flux2KleinVae": "VAE (optional)", - "flux2KleinVaePlaceholder": "From main model", + "flux2KleinVaePlaceholder": "From diffusers model", + "flux2KleinVaeNoModelPlaceholder": "No diffusers model available", "flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)", - "flux2KleinQwen3EncoderPlaceholder": "From main model", + "flux2KleinQwen3EncoderPlaceholder": "From diffusers model", + "flux2KleinQwen3EncoderNoModelPlaceholder": "No diffusers model available", "qwenImageComponentSource": "VAE/Encoder Source (Diffusers)", "qwenImageComponentSourcePlaceholder": "Required for GGUF models", "qwenImageQuantization": "Encoder Quantization", @@ -1623,6 +1625,8 @@ "noFLUXVAEModelSelected": "No VAE model selected for FLUX generation", "noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation", "noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation", + "noFlux2KleinVaeModelSelected": "No VAE selected. Non-diffusers FLUX.2 Klein models require a standalone VAE", + "noFlux2KleinQwen3EncoderModelSelected": "No Qwen3 Encoder selected. Non-diffusers FLUX.2 Klein models require a standalone Qwen3 Encoder", "noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder", "noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model", "noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model", diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts new file mode 100644 index 0000000000..7f01becc3d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -0,0 +1,370 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('app/logging/logger', () => ({ + logger: () => ({ + debug: vi.fn(), + }), +})); + +let nextId = 0; +vi.mock('features/controlLayers/konva/util', () => ({ + getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`, +})); + +// --- Flux2 Klein model fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-klein-diffusers', + hash: 'flux2-diff-hash', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const flux2GGUFModel = { + key: 'flux2-klein-gguf', + hash: 'flux2-gguf-hash', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +}; + +const kleinVaeModelFixture = { key: 'klein-vae', name: 'Klein VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3EncoderModelFixture = { + key: 'klein-qwen3', + name: 'Qwen3 4B', + base: 'flux2', + type: 'qwen3_encoder', +}; + +const flux2GGUF9BModel = { + key: 'flux2-klein-gguf-9b', + hash: 'flux2-gguf-9b-hash', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +}; + +const diffusersSourceModelFixture = { + key: 'flux2-source-diffusers', + hash: 'flux2-src-hash', + name: 'FLUX.2 Klein 4B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const diffusers9BSourceModelFixture = { + key: 'flux2-source-diffusers-9b', + hash: 'flux2-src-9b-hash', + name: 'FLUX.2 Klein 9B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_9b', +}; + +// --- Mutable state --- + +let model: Record = { ...flux2DiffusersModel }; +let kleinVaeModel: Record | null = null; +let kleinQwen3EncoderModel: Record | null = null; +let diffusersModels: Record[] = []; + +vi.mock('features/controlLayers/store/paramsSlice', () => ({ + selectMainModelConfig: vi.fn(() => model), + selectParamsSlice: vi.fn(() => ({ + guidance: 4, + steps: 20, + fluxScheduler: 'euler', + fluxDypePreset: 'off', + fluxDypeScale: 2.0, + fluxDypeExponent: 2.0, + fluxVAE: null, + t5EncoderModel: null, + clipEmbedModel: null, + })), + selectKleinVaeModel: vi.fn(() => kleinVaeModel), + selectKleinQwen3EncoderModel: vi.fn(() => kleinQwen3EncoderModel), +})); + +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ + selectRefImagesSlice: vi.fn(() => ({ + entities: [], + })), +})); + +vi.mock('features/controlLayers/store/selectors', () => ({ + selectCanvasMetadata: vi.fn(() => ({})), + selectCanvasSlice: vi.fn(() => ({})), +})); + +vi.mock('features/controlLayers/store/types', () => ({ + isFlux2ReferenceImageConfig: vi.fn(() => false), + isFluxKontextReferenceImageConfig: vi.fn(() => false), +})); + +vi.mock('features/controlLayers/store/validators', () => ({ + getGlobalReferenceImageWarnings: vi.fn(() => []), +})); + +vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({ + addFlux2KleinLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({ + addFLUXFill: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({ + addFLUXLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({ + addFLUXReduxes: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ + addImageToImage: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ + addInpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ + addNSFWChecker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ + addOutpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ + addRegions: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ + addTextToImage: vi.fn(({ l2i }) => l2i), +})); + +vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ + addWatermarker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({ + addControlLoRA: vi.fn(), + addControlNets: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({ + addIPAdapters: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ + selectCanvasOutputFields: vi.fn(() => ({})), + selectPresetModifiedPrompts: vi.fn(() => ({ + positive: 'a prompt', + negative: '', + })), +})); + +vi.mock('features/ui/store/uiSelectors', () => ({ + selectActiveTab: vi.fn(() => 'generation'), +})); + +vi.mock('services/api/hooks/modelsByType', () => ({ + selectFlux2DiffusersModels: vi.fn(() => diffusersModels), +})); + +vi.mock('services/api/types', async () => { + const actual = await vi.importActual('services/api/types'); + return { + ...actual, + isNonRefinerMainModelConfig: vi.fn(() => true), + }; +}); + +import { buildFLUXGraph } from './buildFLUXGraph'; + +const buildGraphArg = () => ({ + generationMode: 'txt2img' as const, + manager: null, + state: { + system: { + shouldUseNSFWChecker: false, + shouldUseWatermarker: false, + }, + } as never, +}); + +/** Find the flux2_klein_model_loader node in the graph. */ +const getLoaderNode = async () => { + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const loaderEntry = Object.entries(graph.nodes).find(([id]) => id.startsWith('flux2_klein_model_loader:')); + return loaderEntry?.[1] as Record | undefined; +}; + +describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { + afterEach(() => { + nextId = 0; + model = { ...flux2DiffusersModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = null; + diffusersModels = []; + }); + + it('does not set qwen3_source_model when main model is diffusers', async () => { + model = { ...flux2DiffusersModel }; + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when main model is GGUF and a diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toEqual({ + key: diffusersSourceModelFixture.key, + hash: diffusersSourceModelFixture.hash, + name: diffusersSourceModelFixture.name, + base: diffusersSourceModelFixture.base, + type: diffusersSourceModelFixture.type, + }); + }); + + it('does not set qwen3_source_model when main model is GGUF but standalone VAE and Qwen3 are both selected', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('does not set qwen3_source_model when main model is GGUF and no diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = []; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when only VAE is selected but Qwen3 is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = null; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('sets qwen3_source_model when only Qwen3 is selected but VAE is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('passes standalone vae_model and qwen3_encoder_model when selected', async () => { + model = { ...flux2DiffusersModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.vae_model).toEqual(kleinVaeModelFixture); + expect(loader!.qwen3_encoder_model).toEqual(kleinQwen3EncoderModelFixture); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + describe('variant matching', () => { + it('selects a variant-matching diffusers model when multiple are available', async () => { + model = { ...flux2GGUF9BModel }; + diffusersModels = [diffusersSourceModelFixture, diffusers9BSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should pick the 9B diffusers model, not the 4B + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusers9BSourceModelFixture.key })); + }); + + it('falls back to any diffusers model for VAE when standalone Qwen3 is selected but no variant match', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + // Only 4B diffusers available, no 9B — but Qwen3 is already provided standalone + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should use the 4B diffusers model just for VAE extraction + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusersSourceModelFixture.key })); + }); + + it('does not set qwen3_source_model when GGUF 9B with only 4B diffusers available and no standalone Qwen3', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = null; + // Only 4B diffusers available — wrong variant for Qwen3, no standalone Qwen3 selected + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should NOT use the 4B diffusers since it has the wrong Qwen3 encoder + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + }); + + describe('graph structure', () => { + it('uses flux2_klein_model_loader for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_model_loader:'))).toBe(true); + }); + + it('uses flux2_vae_decode for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_vae_decode:'))).toBe(true); + }); + + it('uses flux2_klein_text_encoder for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_text_encoder:'))).toBe(true); + }); + + it('uses flux2_denoise for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeTypes = Object.values(graph.nodes).map((n) => n.type); + expect(nodeTypes).toContain('flux2_denoise'); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index ba27e5dbf6..407c921421 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -10,7 +10,8 @@ import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlic import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isFlux2ReferenceImageConfig, isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types'; import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; -import { zImageField } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { zImageField, zModelIdentifierField } from 'features/nodes/types/common'; import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux2KleinLoRAs'; import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill'; import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; @@ -26,8 +27,10 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { t } from 'i18next'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -141,7 +144,23 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); + const sourceModel = variantMatch ?? (kleinQwen3EncoderModel ? diffusersModels[0] : undefined); + if (sourceModel) { + qwen3SourceModel = zModelIdentifierField.parse(sourceModel); + } + } + modelLoader = g.addNode({ type: 'flux2_klein_model_loader', id: getPrefixedId('flux2_klein_model_loader'), @@ -149,6 +168,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise { const dispatch = useAppDispatch(); const { t } = useTranslation(); const kleinVaeModel = useAppSelector(selectKleinVaeModel); + const mainModelConfig = useAppSelector(selectMainModelConfig); const [modelConfigs, { isLoading }] = useFlux2VAEModels(); + const [diffusersModels] = useFlux2DiffusersModels(); const _onChange = useCallback( (model: VAEModelConfig | null) => { @@ -42,6 +45,11 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { isLoading, }); + const hasDiffusersSource = mainModelConfig?.format === 'diffusers' || diffusersModels.length > 0; + const placeholder = hasDiffusersSource + ? t('modelManager.flux2KleinVaePlaceholder') + : t('modelManager.flux2KleinVaeNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinVae')} @@ -51,7 +59,7 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinVaePlaceholder')} + placeholder={placeholder} /> ); @@ -59,15 +67,6 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { ParamFlux2KleinVaeModelSelect.displayName = 'ParamFlux2KleinVaeModelSelect'; -/** - * Maps FLUX.2 Klein variants to compatible Qwen3 encoder variants - */ -const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { - klein_4b: 'qwen3_4b', - klein_9b: 'qwen3_8b', - klein_9b_base: 'qwen3_8b', -}; - /** * FLUX.2 Klein Qwen3 Encoder Model Select * Selects a Qwen3 text encoder model for FLUX.2 Klein @@ -79,6 +78,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { const kleinQwen3EncoderModel = useAppSelector(selectKleinQwen3EncoderModel); const mainModelConfig = useAppSelector(selectMainModelConfig); const [allModelConfigs, { isLoading }] = useQwen3EncoderModels(); + const [diffusersModels] = useFlux2DiffusersModels(); // Filter Qwen3 encoders based on the main model's variant const modelConfigs = useMemo(() => { @@ -112,6 +112,20 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { isLoading, }); + // Qwen3 encoder requires a Qwen3-compatible diffusers model (variants that share the same Qwen3 encoder). + const hasMatchingDiffusersSource = + mainModelConfig?.format === 'diffusers' || + diffusersModels.some( + (m) => + 'variant' in m && + mainModelConfig && + 'variant' in mainModelConfig && + isFlux2KleinQwen3Compatible(m.variant, mainModelConfig.variant) + ); + const placeholder = hasMatchingDiffusersSource + ? t('modelManager.flux2KleinQwen3EncoderPlaceholder') + : t('modelManager.flux2KleinQwen3EncoderNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinQwen3Encoder')} @@ -121,7 +135,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinQwen3EncoderPlaceholder')} + placeholder={placeholder} /> ); diff --git a/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts new file mode 100644 index 0000000000..b9508a4f82 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts @@ -0,0 +1,24 @@ +/** + * Maps a FLUX.2 Klein main-model variant to the Qwen3 encoder variant it uses. + * Multiple Klein variants can share the same Qwen3 variant (e.g. `klein_9b` and + * `klein_9b_base` both use `qwen3_8b`), so two different Klein variants can be + * Qwen3-compatible sources for each other. + */ +export const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { + klein_4b: 'qwen3_4b', + klein_9b: 'qwen3_8b', + klein_9b_base: 'qwen3_8b', +}; + +/** + * Returns true if two Klein variants share the same Qwen3 encoder and can therefore + * be used as a Qwen3 source for each other. + */ +export const isFlux2KleinQwen3Compatible = (variantA: unknown, variantB: unknown): boolean => { + if (typeof variantA !== 'string' || typeof variantB !== 'string') { + return false; + } + const qwen3A = KLEIN_TO_QWEN3_VARIANT_MAP[variantA]; + const qwen3B = KLEIN_TO_QWEN3_VARIANT_MAP[variantB]; + return qwen3A !== undefined && qwen3A === qwen3B; +}; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts new file mode 100644 index 0000000000..632006050e --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -0,0 +1,272 @@ +import { describe, expect, it, vi } from 'vitest'; + +vi.mock('features/dynamicPrompts/util/getShouldProcessPrompt', () => ({ + getShouldProcessPrompt: vi.fn(() => false), +})); + +vi.mock('i18next', () => ({ + default: { + t: (key: string) => key, + }, +})); + +import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types'; +import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; +import type { MainModelConfig } from 'services/api/types'; + +import { getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness'; + +// --- Fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-diff', + hash: 'h', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF4BModel = { + key: 'flux2-gguf-4b', + hash: 'h', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF9BModel = { + key: 'flux2-gguf-9b', + hash: 'h', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +} as unknown as MainModelConfig; + +const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; + +const baseDynamicPrompts: DynamicPromptsState = { + _version: 1, + maxPrompts: 100, + combinatorial: false, + prompts: ['test prompt'], + parsingError: undefined, + isError: false, + isLoading: false, + seedBehaviour: 'PER_PROMPT', +}; + +const baseRefImages: RefImagesState = { + entities: [], + ipAdapters: { entities: [], ids: [] }, +} as unknown as RefImagesState; + +const baseParams = { + positivePrompt: 'test', + kleinVaeModel: null, + kleinQwen3EncoderModel: null, +} as unknown as ParamsState; + +// --- Helpers --- + +const buildGenerateTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const buildCanvasTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + canvas: { + bbox: { + scaleMethod: 'none', + rect: { width: 1024, height: 1024 }, + scaledSize: { width: 1024, height: 1024 }, + }, + controlLayers: { entities: [] }, + regionalGuidance: { entities: [] }, + rasterLayers: { entities: [] }, + inpaintMasks: { entities: [] }, + }, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + canvasIsFiltering: false, + canvasIsTransforming: false, + canvasIsRasterizing: false, + canvasIsCompositing: false, + canvasIsSelectingObject: false, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const hasFlux2VaeReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinVaeModelSelected')); + +const hasFlux2Qwen3Reason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected')); + +// --- Tests --- + +describe('FLUX.2 Klein readiness checks – generate tab', () => { + it('no errors when main model is diffusers (VAE/Qwen3 extracted from it)', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2DiffusersModel })); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no diffusers source and no standalone models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2GGUF4BModel })); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for Qwen3 when GGUF model with standalone VAE but no Qwen3 and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinVaeModel: kleinVaeModel }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for VAE when GGUF model with standalone Qwen3 but no VAE and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinQwen3EncoderModel: kleinQwen3Model }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B model with only a 4B diffusers source (variant mismatch)', () => { + // User has Klein 9B GGUF selected, only a 4B diffusers model installed. + // VAE is shared across variants so it's ok. Qwen3 encoder differs, so it's not ok. + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF 9B model with variant-matching diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); +}); + +describe('FLUX.2 Klein readiness checks – canvas tab', () => { + it('no errors when main model is diffusers', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2GGUF4BModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B with variant-mismatched diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 67dfe3141c..5802a2aed5 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -33,12 +33,14 @@ import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/ import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue'; import type { UpscaleState } from 'features/parameters/store/upscaleSlice'; import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { getGridSize } from 'features/parameters/util/optimalDimension'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import type { TabName } from 'features/ui/store/uiTypes'; import i18n from 'i18next'; import { atom, computed } from 'nanostores'; import { useEffect } from 'react'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { MainOrExternalModelConfig } from 'services/api/types'; import { isExternalApiModelConfig } from 'services/api/types'; import { $isConnected } from 'services/events/stores'; @@ -109,6 +111,12 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { } = arg; if (tab === 'generate') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueGenerateTab({ isConnected, model, @@ -116,10 +124,18 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { refImages, dynamicPrompts, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'canvas') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueCanvasTab({ isConnected, model, @@ -133,6 +149,8 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { canvasIsCompositing, canvasIsSelectingObject, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'workflows') { @@ -220,15 +238,26 @@ export const useReadinessWatcher = () => { const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') }); -const getReasonsWhyCannotEnqueueGenerateTab = (arg: { +export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; params: ParamsState; refImages: RefImagesState; loras: LoRA[]; dynamicPrompts: DynamicPromptsState; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { - const { isConnected, model, params, refImages, loras, dynamicPrompts } = arg; + const { + isConnected, + model, + params, + refImages, + loras, + dynamicPrompts, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, + } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -260,7 +289,17 @@ const getReasonsWhyCannotEnqueueGenerateTab = (arg: { } } - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + if (model?.base === 'flux2' && model.format !== 'diffusers') { + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } if (model?.base === 'qwen-image' && model.format === 'gguf_quantized') { if (!params.qwenImageComponentSource) { @@ -452,7 +491,7 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: { return reasons; }; -const getReasonsWhyCannotEnqueueCanvasTab = (arg: { +export const getReasonsWhyCannotEnqueueCanvasTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; canvas: CanvasState; @@ -465,6 +504,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing: boolean; canvasIsCompositing: boolean; canvasIsSelectingObject: boolean; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { const { isConnected, @@ -479,6 +520,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing, canvasIsCompositing, canvasIsSelectingObject, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -571,7 +614,17 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { } if (model?.base === 'flux2') { - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (model.format !== 'diffusers') { + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } const { bbox } = canvas; const gridSize = getGridSize('flux'); // FLUX.2 uses same grid size as FLUX.1 diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 55746e5294..2496c06ed0 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -18,6 +18,7 @@ import { isControlNetModelConfig, isExternalApiModelConfig, isFlux1VAEModelConfig, + isFlux2DiffusersMainModelConfig, isFlux2VAEModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig, @@ -101,6 +102,7 @@ export const useFlux2VAEModels = () => buildModelsHook(isFlux2VAEModelConfig)(); export const useAnimaVAEModels = () => buildModelsHook(isAnimaVAEModelConfig)(); export const useAnimaQwen3EncoderModels = () => buildModelsHook(isAnimaQwen3EncoderModelConfig)(); export const useZImageDiffusersModels = () => buildModelsHook(isZImageDiffusersMainModelConfig)(); +export const useFlux2DiffusersModels = () => buildModelsHook(isFlux2DiffusersMainModelConfig)(); export const useQwenImageDiffusersModels = () => buildModelsHook(isQwenImageDiffusersMainModelConfig)(); export const useQwen3EncoderModels = () => buildModelsHook(isQwen3EncoderModelConfig)(); export const useGlobalReferenceImageModels = buildModelsHook( @@ -140,6 +142,7 @@ export const selectAnimaQwen3EncoderModels = buildModelsSelector(isAnimaQwen3Enc export const selectQwen3EncoderModels = buildModelsSelector(isQwen3EncoderModelConfig); export const selectQwenImageDiffusersModels = buildModelsSelector(isQwenImageDiffusersMainModelConfig); export const selectZImageDiffusersModels = buildModelsSelector(isZImageDiffusersMainModelConfig); +export const selectFlux2DiffusersModels = buildModelsSelector(isFlux2DiffusersMainModelConfig); export const selectFluxVAEModels = buildModelsSelector(isFluxVAEModelConfig); export const selectAnimaVAEModels = buildModelsSelector(isAnimaVAEModelConfig); export const selectT5EncoderModels = buildModelsSelector(isT5EncoderModelConfigOrSubmodel); diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 3624d7ef6a..9deefada23 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -457,6 +457,10 @@ export const isZImageDiffusersMainModelConfig = (config: AnyModelConfig): config return config.type === 'main' && config.base === 'z-image' && config.format === 'diffusers'; }; +export const isFlux2DiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base === 'flux2' && config.format === 'diffusers'; +}; + export const isQwenImageDiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'qwen-image' && config.format === 'diffusers'; }; From d7d623e1d5c17dcc64e5fe82205dd2b513099ba9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ufuk=20Sarp=20Sel=C3=A7ok?= Date: Tue, 21 Apr 2026 00:18:25 +0300 Subject: [PATCH 2/7] Fix: Preserve reference image panel state and selection on recall (#9010) * Fix: Preserve reference image panel state and selection on recall Made-with: Cursor * chore: apply Prettier to refImagesSlice Made-with: Cursor * fix: refine ref image recall selection behavior Simplify image-name fingerprint fallback, add an explicit guard for open-panel with null selection, and document the acceptable empty-config collision tradeoff. Made-with: Cursor --------- Co-authored-by: Alexander Eichhorn --- .../controlLayers/store/refImagesSlice.ts | 47 +++++++++++++++++-- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index 2b7c0f7d17..1ea7626290 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -35,6 +35,15 @@ type PayloadActionWithId = T extends void } & T >; +/** Fingerprint used to match the same reference image entry after recall when ids are regenerated. */ +/** Empty configs of the same type may collide; the worst case is selecting an equivalent empty entity. */ +const getRefImageRecallMatchKey = (entity: RefImageState): string => { + const { config } = entity; + const imageName = config.image?.original.image.image_name ?? ''; + const modelKey = 'model' in config && config.model ? config.model.key : ''; + return `${config.type}\0${modelKey}\0${imageName}`; +}; + const slice = createSlice({ name: 'refImages', initialState: getInitialRefImagesState(), @@ -54,13 +63,41 @@ const slice = createSlice({ }, refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => { const { entities, replace } = action.payload; - if (replace) { - state.entities = entities; - state.isPanelOpen = false; - state.selectedEntityId = null; - } else { + if (!replace) { state.entities.push(...entities); + return; } + const wasPanelOpen = state.isPanelOpen; + const previousSelectedId = state.selectedEntityId; + let previousEntity: RefImageState | null = null; + if (previousSelectedId !== null) { + previousEntity = state.entities.find((e) => e.id === previousSelectedId) ?? null; + } + state.entities = entities; + if (entities.length === 0) { + state.selectedEntityId = null; + state.isPanelOpen = false; + return; + } + if (!wasPanelOpen) { + state.selectedEntityId = null; + return; + } + const firstEntity = entities[0]; + assert(firstEntity); + if (previousSelectedId === null) { + // Open panel must have a selection; otherwise, fall back to the first entity. + state.selectedEntityId = firstEntity.id; + return; + } + if (previousSelectedId !== null && entities.some((e) => e.id === previousSelectedId)) { + state.selectedEntityId = previousSelectedId; + return; + } + const previousKey = previousEntity ? getRefImageRecallMatchKey(previousEntity) : null; + const matched = + previousKey !== null ? entities.find((e) => getRefImageRecallMatchKey(e) === previousKey) : undefined; + state.selectedEntityId = matched?.id ?? firstEntity.id; }, refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => { const { id, croppableImage } = action.payload; From b2d79dc86c3228f8e19bb8065633dc6f9eff7925 Mon Sep 17 00:00:00 2001 From: skunkworxdark Date: Tue, 21 Apr 2026 01:08:09 +0100 Subject: [PATCH 3/7] feat:(model-manager) add sorting capabilities for models (#9024) * feat(model-manager): add comprehensive sorting capabilities for models dded the ability to sort models in the Model Manager by various attributes including Name, Base, Type, Format, Size, Date Added, and Date Modified. Supports both ascending and descending order. - Backend: Added `order_by` and `direction` query parameters to the ``/api/v1/models`/` listing endpoint. Implemented case-insensitive sorting in the SQLite model records service. - Frontend: Introduced `` UI, updated Redux slices to manage sort state, removed client-side entity adapter sorting to respect server-side ordering, and added i18n localization keys. - Tests: Added test coverage for SQL-based sorting on size and name. * feat(model-manager): add comprehensive sorting capabilities for models dded the ability to sort models in the Model Manager by various attributes including Name, Base, Type, Format, Size, Date Added, and Date Modified. Supports both ascending and descending order. - Backend: Added `order_by` and `direction` query parameters to the ``/api/v1/models`/` listing endpoint. Implemented case-insensitive sorting in the SQLite model records service. - Frontend: Introduced `` UI, updated Redux slices to manage sort state, removed client-side entity adapter sorting to respect server-side ordering, and added i18n localization keys. - Tests: Added test coverage for SQL-based sorting on size and name. * ruff fix * typegen fix * typegen fix - this time without my custom nodes. * another typegen fix * refactor(ui): consolidate model filter and sort controls into a unified menu - Replaced separate `ModelSortControl` and `ModelTypeFilter` components with a single, unified "Filtering" dropdown menu. - Organised filtering options into categorised submenus in the following order: Direction, Sort By, and Model Type. - Enhanced submenu labels to display the currently active selection inline for quick reference. - Improved visual alignment within menus by using hidden checkmarks on unselected items, ensuring consistent indentation across all options. - Resolved styling and linting issues (unused variables, JSX bind warnings) within the new component. * Lint fix * Addresses PR feedback to use translation strings directly within `ORDER_BY_OPTIONS`. Previously, sort keys and their translated labels were maintained in separate constructs (`ORDER_BY_OPTIONS` array and `ORDER_BY_LABELS` map). This refactor converts `ORDER_BY_OPTIONS` into an array of objects containing both the `key` and its corresponding `i18nKey`, creating a single source of truth. This change: - Simplifies the `SortBySubMenu` component by removing the redundant `ORDER_BY_LABELS` lookup map. - Improves maintainability by ensuring developers only need to update one place when adding or modifying sort options. - Reduces the risk of mismatched keys and labels. --------- Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com> Co-authored-by: Alexander Eichhorn Co-authored-by: Lincoln Stein --- invokeai/app/api/routers/model_manager.py | 19 +- .../model_records/model_records_base.py | 13 +- .../model_records/model_records_sql.py | 35 ++- invokeai/frontend/web/public/locales/en.json | 9 + .../web/src/common/hooks/useSubMenu.tsx | 11 +- .../web/src/features/modelManagerV2/models.ts | 2 +- .../store/modelManagerV2Slice.ts | 16 ++ .../ModelManagerPanel/ModelFilterMenu.tsx | 231 ++++++++++++++++++ .../subpanels/ModelManagerPanel/ModelList.tsx | 7 +- .../ModelManagerPanel/ModelListNavigation.tsx | 4 +- .../ModelManagerPanel/ModelTypeFilter.tsx | 78 ------ .../web/src/services/api/endpoints/models.ts | 15 +- .../frontend/web/src/services/api/schema.ts | 10 + .../model_records/test_model_records_sql.py | 69 ++++++ 14 files changed, 419 insertions(+), 100 deletions(-) create mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx delete mode 100644 invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index f351be11ad..40d4f48b63 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -26,9 +26,11 @@ from invokeai.app.services.model_install.model_install_common import ModelInstal from invokeai.app.services.model_records import ( InvalidModelException, ModelRecordChanges, + ModelRecordOrderBy, UnknownModelException, ) from invokeai.app.services.orphaned_models import OrphanedModelInfo +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory @@ -159,6 +161,8 @@ async def list_model_records( model_format: Optional[ModelFormat] = Query( default=None, description="Exact match on the format of the model (e.g. 'diffusers')" ), + order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Name, description="The field to order by"), + direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_manager.store @@ -167,12 +171,23 @@ async def list_model_records( for base_model in base_models: found_models.extend( record_store.search_by_attr( - base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + base_model=base_model, + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, ) ) else: found_models.extend( - record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + record_store.search_by_attr( + model_type=model_type, + model_name=model_name, + model_format=model_format, + order_by=order_by, + direction=direction, + ) ) for index, model in enumerate(found_models): found_models[index] = prepare_model_config_for_response(model, ApiDependencies) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 6420949c29..31fbadb3cb 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -11,6 +11,7 @@ from typing import List, Optional, Set, Union from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.util.model_exclude_null import BaseModelExcludeNull from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.external_api import ( @@ -60,6 +61,10 @@ class ModelRecordOrderBy(str, Enum): Base = "base" Name = "name" Format = "format" + Size = "size" + DateAdded = "created_at" + DateModified = "updated_at" + Path = "path" class ModelSummary(BaseModel): @@ -200,7 +205,11 @@ class ModelRecordServiceBase(ABC): @abstractmethod def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" pass @@ -237,6 +246,8 @@ class ModelRecordServiceBase(ABC): base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index edcbba2acd..f104c3855e 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -57,6 +57,7 @@ from invokeai.app.services.model_records.model_records_base import ( UnknownModelException, ) from invokeai.app.services.shared.pagination import PaginatedResults +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType @@ -257,6 +258,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model_type: Optional[ModelType] = None, model_format: Optional[ModelFormat] = None, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -266,18 +268,24 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param model_type: Filter by type of model (optional) :param model_format: Filter by model format (e.g. "diffusers") (optional) :param order_by: Result order + :param direction: Result direction If none of the optional filters are passed, will return all models in the database. """ with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } where_clause: list[str] = [] @@ -301,7 +309,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): SELECT config FROM models {where} - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason; + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason; """, tuple(bindings), ) @@ -357,17 +365,26 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return results def list_models( - self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default + self, + page: int = 0, + per_page: int = 10, + order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default, + direction: SQLiteDirection = SQLiteDirection.Ascending, ) -> PaginatedResults[ModelSummary]: """Return a paginated summary listing of each model in the database.""" with self._db.transaction() as cursor: assert isinstance(order_by, ModelRecordOrderBy) + order_dir = "DESC" if direction == SQLiteDirection.Descending else "ASC" ordering = { - ModelRecordOrderBy.Default: "type, base, name, format", + ModelRecordOrderBy.Default: f"type {order_dir}, base COLLATE NOCASE {order_dir}, name COLLATE NOCASE {order_dir}, format", ModelRecordOrderBy.Type: "type", - ModelRecordOrderBy.Base: "base", - ModelRecordOrderBy.Name: "name", + ModelRecordOrderBy.Base: "base COLLATE NOCASE", + ModelRecordOrderBy.Name: "name COLLATE NOCASE", ModelRecordOrderBy.Format: "format", + ModelRecordOrderBy.Size: "IFNULL(json_extract(config, '$.file_size'), 0)", + ModelRecordOrderBy.DateAdded: "created_at", + ModelRecordOrderBy.DateModified: "updated_at", + ModelRecordOrderBy.Path: "path", } # Lock so that the database isn't updated while we're doing the two queries. @@ -385,7 +402,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): f"""--sql SELECT config FROM models - ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason + ORDER BY {ordering[order_by]} {order_dir} -- using ? to bind doesn't work here for some reason LIMIT ? OFFSET ?; """, diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 3e7e742934..75c5ad6671 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1203,6 +1203,15 @@ "modelType": "Model Type", "modelUpdated": "Model Updated", "modelUpdateFailed": "Model Update Failed", + "sortByName": "Name", + "sortByBase": "Base", + "sortBySize": "Size", + "sortByDateAdded": "Date Added", + "sortByDateModified": "Date Modified", + "sortByPath": "Path", + "sortByType": "Type", + "sortByFormat": "Format", + "sortDefault": "Default", "name": "Name", "externalProvider": "External Provider", "externalCapabilities": "External Capabilities", diff --git a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx index f8ea01909a..4c1bc56e49 100644 --- a/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx +++ b/invokeai/frontend/web/src/common/hooks/useSubMenu.tsx @@ -151,11 +151,18 @@ export const useSubMenu = (): UseSubMenuReturn => { }; }; -export const SubMenuButtonContent = ({ label }: { label: string }) => { +export const SubMenuButtonContent = ({ label, value }: { label: string; value?: string }) => { return ( {label} - + + {value !== undefined && ( + + {value} + + )} + + ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 9cc4ed24d9..7cdba474bb 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -31,7 +31,7 @@ export type ModelCategoryData = { filter: (config: AnyModelConfig) => boolean; }; -export const MODEL_CATEGORIES: Record = { +const MODEL_CATEGORIES: Record = { unknown: { category: 'unknown', i18nKey: 'common.unknown', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts index 44df38d911..91fb1afd4d 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts @@ -25,6 +25,10 @@ const zModelManagerState = z.object({ scanPath: z.string().optional(), shouldInstallInPlace: z.boolean(), selectedModelKeys: z.array(z.string()), + orderBy: z + .enum(['default', 'name', 'type', 'base', 'size', 'created_at', 'updated_at', 'path', 'format']) + .default('name'), + sortDirection: z.enum(['asc', 'desc']).default('asc'), }); type ModelManagerState = z.infer; @@ -38,6 +42,8 @@ const getInitialState = (): ModelManagerState => ({ scanPath: undefined, shouldInstallInPlace: true, selectedModelKeys: [], + orderBy: 'name', + sortDirection: 'asc', }); const slice = createSlice({ @@ -77,6 +83,12 @@ const slice = createSlice({ clearModelSelection: (state) => { state.selectedModelKeys = []; }, + setOrderBy: (state, action: PayloadAction) => { + state.orderBy = action.payload; + }, + setSortDirection: (state, action: PayloadAction) => { + state.sortDirection = action.payload; + }, }, }); @@ -90,6 +102,8 @@ export const { modelSelectionChanged, toggleModelSelection, clearModelSelection, + setOrderBy, + setSortDirection, } = slice.actions; export const modelManagerSliceConfig: SliceConfig = { @@ -119,3 +133,5 @@ export const selectSearchTerm = createModelManagerSelector((mm) => mm.searchTerm export const selectFilteredModelType = createModelManagerSelector((mm) => mm.filteredModelType); export const selectShouldInstallInPlace = createModelManagerSelector((mm) => mm.shouldInstallInPlace); export const selectSelectedModelKeys = createModelManagerSelector((mm) => mm.selectedModelKeys); +export const selectOrderBy = createModelManagerSelector((mm) => mm.orderBy); +export const selectSortDirection = createModelManagerSelector((mm) => mm.sortDirection); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx new file mode 100644 index 0000000000..57dad58f2c --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFilterMenu.tsx @@ -0,0 +1,231 @@ +import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; +import type { ModelCategoryData } from 'features/modelManagerV2/models'; +import { MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; +import { + selectFilteredModelType, + selectOrderBy, + selectSortDirection, + setFilteredModelType, + setOrderBy, + setSortDirection, +} from 'features/modelManagerV2/store/modelManagerV2Slice'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + PiCheckBold, + PiFunnelBold, + PiListBold, + PiSortAscendingBold, + PiSortDescendingBold, + PiWarningBold, +} from 'react-icons/pi'; + +type OrderBy = 'default' | 'name' | 'type' | 'base' | 'size' | 'created_at' | 'updated_at' | 'path' | 'format'; + +const ORDER_BY_OPTIONS: { key: OrderBy; i18nKey: string }[] = [ + { key: 'default', i18nKey: 'modelManager.sortDefault' }, + { key: 'name', i18nKey: 'modelManager.sortByName' }, + { key: 'base', i18nKey: 'modelManager.sortByBase' }, + { key: 'size', i18nKey: 'modelManager.sortBySize' }, + { key: 'created_at', i18nKey: 'modelManager.sortByDateAdded' }, + { key: 'updated_at', i18nKey: 'modelManager.sortByDateModified' }, + { key: 'path', i18nKey: 'modelManager.sortByPath' }, + { key: 'type', i18nKey: 'modelManager.sortByType' }, + { key: 'format', i18nKey: 'modelManager.sortByFormat' }, +]; + +const SortByMenuItem = memo(({ option, label }: { option: OrderBy; label: string }) => { + const dispatch = useAppDispatch(); + const orderBy = useAppSelector(selectOrderBy); + const onClick = useCallback(() => { + dispatch(setOrderBy(option)); + }, [dispatch, option]); + + return ( + : } + > + {label} + + ); +}); +SortByMenuItem.displayName = 'SortByMenuItem'; + +const SortBySubMenu = memo(() => { + const { t } = useTranslation(); + const subMenu = useSubMenu(); + const orderBy = useAppSelector(selectOrderBy); + + const currentSortLabel = useMemo(() => { + const option = ORDER_BY_OPTIONS.find((o) => o.key === orderBy); + if (!option) { + return ''; + } + return t(option.i18nKey); + }, [orderBy, t]); + + return ( + }> + + + + + + {ORDER_BY_OPTIONS.map(({ key, i18nKey }) => ( + + ))} + + + + ); +}); +SortBySubMenu.displayName = 'SortBySubMenu'; + +const DirectionSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const direction = useAppSelector(selectSortDirection); + const subMenu = useSubMenu(); + + const setDirectionAsc = useCallback(() => { + dispatch(setSortDirection('asc')); + }, [dispatch]); + + const setDirectionDesc = useCallback(() => { + dispatch(setSortDirection('desc')); + }, [dispatch]); + + const currentValue = direction === 'asc' ? t('common.ascending', 'Ascending') : t('common.descending', 'Descending'); + + return ( + : } + > + + + + + + : } + > + {t('common.ascending', 'Ascending')} + + : } + > + {t('common.descending', 'Descending')} + + + + + ); +}); +DirectionSubMenu.displayName = 'DirectionSubMenu'; + +const ModelTypeSubMenu = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const subMenu = useSubMenu(); + + const clearModelType = useCallback(() => { + dispatch(setFilteredModelType(null)); + }, [dispatch]); + + const setMissingFilter = useCallback(() => { + dispatch(setFilteredModelType('missing')); + }, [dispatch]); + + const currentValue = useMemo(() => { + if (filteredModelType === null) { + return t('modelManager.allModels'); + } + if (filteredModelType === 'missing') { + return t('modelManager.missingFiles'); + } + const categoryData = MODEL_CATEGORIES_AS_LIST.find((data) => data.category === filteredModelType); + return categoryData ? t(categoryData.i18nKey) : ''; + }, [filteredModelType, t]); + + return ( + }> + + + + + + : } + > + {t('modelManager.allModels')} + + : } + > + + {filteredModelType !== 'missing' && } + {t('modelManager.missingFiles')} + + + {MODEL_CATEGORIES_AS_LIST.map((data) => ( + + ))} + + + + ); +}); +ModelTypeSubMenu.displayName = 'ModelTypeSubMenu'; + +const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const filteredModelType = useAppSelector(selectFilteredModelType); + const onClick = useCallback(() => { + dispatch(setFilteredModelType(data.category)); + }, [data.category, dispatch]); + return ( + : } + > + {t(data.i18nKey)} + + ); +}); +ModelMenuItem.displayName = 'ModelMenuItem'; + +export const ModelFilterMenu = memo(() => { + const { t } = useTranslation(); + + return ( + + }> + {t('common.filtering', 'Filtering')} + + + + + + + + ); +}); + +ModelFilterMenu.displayName = 'ModelFilterMenu'; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index ed49fa2870..033a439bfc 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -8,8 +8,10 @@ import { clearModelSelection, type FilterableModelType, selectFilteredModelType, + selectOrderBy, selectSearchTerm, selectSelectedModelKeys, + selectSortDirection, setSelectedModelKey, } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useCallback, useMemo, useState } from 'react'; @@ -39,6 +41,8 @@ const ModelList = () => { const dispatch = useAppDispatch(); const filteredModelType = useAppSelector(selectFilteredModelType); const searchTerm = useAppSelector(selectSearchTerm); + const orderBy = useAppSelector(selectOrderBy); + const direction = useAppSelector(selectSortDirection); const selectedModelKeys = useAppSelector(selectSelectedModelKeys); const { t } = useTranslation(); const toast = useToast(); @@ -47,7 +51,8 @@ const ModelList = () => { const [isDeleting, setIsDeleting] = useState(false); const [isReidentifying, setIsReidentifying] = useState(false); - const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(); + const queryArgs = useMemo(() => ({ order_by: orderBy, direction: direction.toUpperCase() }), [orderBy, direction]); + const { data: allModelsData, isLoading: isLoadingAll } = useGetModelConfigsQuery(queryArgs); const { data: missingModelsData, isLoading: isLoadingMissing } = useGetMissingModelsQuery(); const [bulkDeleteModels] = useBulkDeleteModelsMutation(); const [bulkReidentifyModels] = useBulkReidentifyModelsMutation(); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx index 78bed8ab83..bbfb88df5c 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListNavigation.tsx @@ -6,8 +6,8 @@ import type { ChangeEventHandler } from 'react'; import { memo, useCallback } from 'react'; import { PiXBold } from 'react-icons/pi'; +import { ModelFilterMenu } from './ModelFilterMenu'; import { ModelListBulkActions } from './ModelListBulkActions'; -import { ModelTypeFilter } from './ModelTypeFilter'; export const ModelListNavigation = memo(() => { const dispatch = useAppDispatch(); @@ -50,7 +50,7 @@ export const ModelListNavigation = memo(() => { - + diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx deleted file mode 100644 index 5aa8e62886..0000000000 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ /dev/null @@ -1,78 +0,0 @@ -import { Button, Flex, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import type { ModelCategoryData } from 'features/modelManagerV2/models'; -import { MODEL_CATEGORIES, MODEL_CATEGORIES_AS_LIST } from 'features/modelManagerV2/models'; -import type { ModelCategoryType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { selectFilteredModelType, setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice'; -import { memo, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import { PiFunnelBold, PiWarningBold } from 'react-icons/pi'; - -const isModelCategoryType = (type: string): type is ModelCategoryType => { - return type in MODEL_CATEGORIES; -}; - -export const ModelTypeFilter = memo(() => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - - const clearModelType = useCallback(() => { - dispatch(setFilteredModelType(null)); - }, [dispatch]); - - const setMissingFilter = useCallback(() => { - dispatch(setFilteredModelType('missing')); - }, [dispatch]); - - const getButtonLabel = () => { - if (filteredModelType === 'missing') { - return t('modelManager.missingFiles'); - } - if (filteredModelType && isModelCategoryType(filteredModelType)) { - return t(MODEL_CATEGORIES[filteredModelType].i18nKey); - } - return t('modelManager.allModels'); - }; - - return ( - - }> - {getButtonLabel()} - - - {t('modelManager.allModels')} - - - - {t('modelManager.missingFiles')} - - - {MODEL_CATEGORIES_AS_LIST.map((data) => ( - - ))} - - - ); -}); - -ModelTypeFilter.displayName = 'ModelTypeFilter'; - -const ModelMenuItem = memo(({ data }: { data: ModelCategoryData }) => { - const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const filteredModelType = useAppSelector(selectFilteredModelType); - const onClick = useCallback(() => { - dispatch(setFilteredModelType(data.category)); - }, [data.category, dispatch]); - return ( - - {t(data.i18nKey)} - - ); -}); -ModelMenuItem.displayName = 'ModelMenuItem'; diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c3d0decd53..f279d46d82 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -111,9 +111,13 @@ type DeleteOrphanedModelsResponse = { errors: Record; }; +type GetModelConfigsArg = { + order_by?: string; + direction?: string; +} | void; + const modelConfigsAdapter = createEntityAdapter({ selectId: (entity) => entity.key, - sortComparer: (a, b) => a.name.localeCompare(b.name), }); export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions); @@ -338,8 +342,11 @@ export const modelsApi = api.injectEndpoints({ }, invalidatesTags: ['ModelInstalls'], }), - getModelConfigs: build.query, void>({ - query: () => ({ url: buildModelsUrl() }), + getModelConfigs: build.query, GetModelConfigsArg>({ + query: (arg) => { + const queryStr = arg ? `?${queryString.stringify(arg)}` : ''; + return { url: buildModelsUrl(queryStr) }; + }, providesTags: (result) => { const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }]; if (result) { @@ -498,5 +505,5 @@ export const { useDeleteOrphanedModelsMutation, } = modelsApi; -export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(); +export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(undefined); export const selectMissingModelsQuery = modelsApi.endpoints.getMissingModels.select(); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 6c6e259fd1..e1dd2ad361 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -22997,6 +22997,12 @@ export type components = { */ config_path?: string | null; }; + /** + * ModelRecordOrderBy + * @description The order in which to return model summaries. + * @enum {string} + */ + ModelRecordOrderBy: "default" | "type" | "base" | "name" | "format" | "size" | "created_at" | "updated_at" | "path"; /** ModelRelationshipBatchRequest */ ModelRelationshipBatchRequest: { /** @@ -31525,6 +31531,10 @@ export interface operations { model_name?: string | null; /** @description Exact match on the format of the model (e.g. 'diffusers') */ model_format?: components["schemas"]["ModelFormat"] | null; + /** @description The field to order by */ + order_by?: components["schemas"]["ModelRecordOrderBy"]; + /** @description The direction to order by */ + direction?: components["schemas"]["SQLiteDirection"]; }; header?: never; path?: never; diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 2b6c54d5b0..19a1b74e73 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -11,11 +11,13 @@ from pydantic import ValidationError from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.model_records import ( DuplicateModelException, + ModelRecordOrderBy, ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException, ) from invokeai.app.services.model_records.model_records_base import ModelRecordChanges +from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings from invokeai.backend.model_manager.configs.lora import LoRA_LyCORIS_SDXL_Config from invokeai.backend.model_manager.configs.main import ( @@ -364,6 +366,73 @@ def test_filter_2(store: ModelRecordServiceBase): assert len(matches) == 1 +def test_search_by_attr_sorting(store: ModelRecordServiceSQL): + config1 = Main_Diffusers_SD1_Config( + path="/tmp/config1", + name="alpha", + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, + hash="CONFIG1HASH", + file_size=1000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config2 = Main_Diffusers_SD2_Config( + path="/tmp/config2", + name="beta", + base=BaseModelType.StableDiffusion2, + type=ModelType.Main, + hash="CONFIG2HASH", + file_size=2000, + source="test/source/", + source_type=ModelSourceType.Path, + variant=ModelVariantType.Normal, + prediction_type=SchedulerPredictionType.Epsilon, + repo_variant=ModelRepoVariant.Default, + ) + config3 = VAE_Diffusers_SD1_Config( + path="/tmp/config3", + name="gamma", + base=BaseModelType.StableDiffusion1, + type=ModelType.VAE, + hash="CONFIG3HASH", + file_size=500, + source="test/source/", + source_type=ModelSourceType.Path, + repo_variant=ModelRepoVariant.Default, + ) + for c in config1, config2, config3: + store.add_model(c) + + # Test sorting by Name Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Ascending) + assert len(matches) == 3 + assert matches[0].name == "alpha" + assert matches[1].name == "beta" + assert matches[2].name == "gamma" + + # Test sorting by Name Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Name, direction=SQLiteDirection.Descending) + assert matches[0].name == "gamma" + assert matches[1].name == "beta" + assert matches[2].name == "alpha" + + # Test sorting by Size Ascending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Ascending) + assert matches[0].name == "gamma" # 500 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "beta" # 2000 + + # Test sorting by Size Descending + matches = store.search_by_attr(order_by=ModelRecordOrderBy.Size, direction=SQLiteDirection.Descending) + assert matches[0].name == "beta" # 2000 + assert matches[1].name == "alpha" # 1000 + assert matches[2].name == "gamma" # 500 + + def test_model_record_changes(): # This test guards against some unexpected behaviours from pydantic's union evaluation. See #6035 changes = ModelRecordChanges.model_validate({"default_settings": {"preprocessor": "value"}}) From 3e318614b64b6ce12c4add4d3abe4554066b726d Mon Sep 17 00:00:00 2001 From: "Weblate (bot)" Date: Tue, 21 Apr 2026 02:15:05 +0200 Subject: [PATCH 4/7] translationBot(ui): update translation (Italian) (#9078) Currently translated at 97.2% (2521 of 2592 strings) Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/ Translation: InvokeAI/Web UI Co-authored-by: Riccardo Giovanetti --- invokeai/frontend/web/public/locales/it.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index db0a9a11a6..d823258dbf 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -2660,7 +2660,9 @@ "fitModeCover": "Copri", "smoothingMode": "Modalità di ricampionamento", "smoothingDesc": "Applica un ricampionamento di alta qualità lato backend alla conferma delle trasformazioni.", - "smoothing": "Smussamento" + "smoothing": "Smussamento", + "smoothingModeBilinear": "Bilineare", + "smoothingModeBicubic": "Bicubico" }, "stagingArea": { "next": "Successiva", From e521817aa423f5c9bee117ffbc7a11e502bc187d Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 21 Apr 2026 02:28:10 +0200 Subject: [PATCH 5/7] Fix Z-Image LoRA detection for Kohya and ComfyUI formats (#9007) Add support for Kohya-format Z-Image LoRAs (lora_unet__ prefix) by adding key detection and conversion to dot-notation module paths. Fix ComfyUI-format Z-Image LoRAs being misidentified as main models by ensuring LoRA-specific suffixes (including .alpha) are checked before Z-Image key matching in _has_z_image_keys(). --- .../backend/model_manager/configs/lora.py | 26 ++++++- .../backend/model_manager/configs/main.py | 9 ++- .../z_image_lora_conversion_utils.py | 75 ++++++++++++++++++- 3 files changed, 103 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 88f917d0d3..46606a3c0d 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -714,14 +714,25 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): - diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format) - diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format) - diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale) + - lora_unet__layers_X_attention_to_k.lora_down.weight (Kohya format) """ + from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import ( + is_state_dict_likely_z_image_kohya_lora, + ) + state_dict = mod.load_state_dict() - # Check for Z-Image specific LoRA patterns + # Check for Kohya format first + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return + + # Check for Z-Image specific LoRA patterns (dot-notation formats) has_z_image_lora_keys = state_dict_has_any_keys_starting_with( state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "diffusion_model.context_refiner.", + "diffusion_model.noise_refiner.", "transformer.layers.", # OneTrainer/diffusers prefix variant "base_model.model.transformer.layers.", # PEFT-wrapped variant }, @@ -751,15 +762,26 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): Z-Image uses S3-DiT architecture with layer names like: - diffusion_model.layers.0.attention.to_k.lora_A.weight - diffusion_model.layers.0.feed_forward.w1.lora_A.weight + - lora_unet__layers_0_attention_to_k.lora_down.weight (Kohya format) """ + from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import ( + is_state_dict_likely_z_image_kohya_lora, + ) + state_dict = mod.load_state_dict() - # Check for Z-Image transformer layer patterns + # Check for Kohya format + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return BaseModelType.ZImage + + # Check for Z-Image transformer layer patterns (dot-notation formats) # Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks) has_z_image_keys = state_dict_has_any_keys_starting_with( state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "diffusion_model.context_refiner.", + "diffusion_model.noise_refiner.", "transformer.layers.", # OneTrainer/diffusers prefix variant "base_model.model.transformer.layers.", # PEFT-wrapped variant }, diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 1be349f394..a2f008f41e 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -160,17 +160,20 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool: ".lora_A.weight", ".lora_B.weight", ".dora_scale", + ".alpha", ) + # First pass: check if any key has LoRA suffixes - if so, this is a LoRA not a main model for key in state_dict.keys(): if isinstance(key, int): continue - - # If we find any LoRA-specific keys, this is not a main model if key.endswith(lora_suffixes): return False - # Check for Z-Image specific key prefixes + # Second pass: check for Z-Image specific key parts + for key in state_dict.keys(): + if isinstance(key, int): + continue # Handle both direct keys (cap_embedder.0.weight) and # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight) key_parts = key.split(".") diff --git a/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py index e248f9cfc4..70b10de50d 100644 --- a/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py @@ -1,10 +1,11 @@ """Z-Image LoRA conversion utilities. Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder. -LoRAs for Z-Image typically follow the diffusers PEFT format. +LoRAs for Z-Image typically follow the diffusers PEFT format or Kohya format. """ -from typing import Dict +import re +from typing import Any, Dict import torch @@ -16,6 +17,29 @@ from invokeai.backend.patches.lora_conversions.z_image_lora_constants import ( ) from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +# Regex for Kohya-format Z-Image transformer keys. +# Example keys: +# lora_unet__layers_0_attention_to_k.alpha +# lora_unet__layers_0_attention_to_k.lora_down.weight +# lora_unet__context_refiner_0_feed_forward_w1.lora_up.weight +# lora_unet__noise_refiner_1_attention_to_v.lora_down.weight +Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX = ( + r"lora_unet__(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)" +) + + +def is_state_dict_likely_z_image_kohya_lora(state_dict: dict[str | int, Any]) -> bool: + """Checks if the provided state dict is likely a Z-Image LoRA in Kohya format. + + Kohya Z-Image LoRAs have keys like: + - lora_unet__layers_0_attention_to_k.lora_down.weight + - lora_unet__context_refiner_0_feed_forward_w1.alpha + - lora_unet__noise_refiner_1_attention_to_v.lora_up.weight + """ + return any( + isinstance(k, str) and re.match(Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX, k.split(".")[0]) for k in state_dict.keys() + ) + def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool: """Checks if the provided state dict is likely a Z-Image LoRA. @@ -23,6 +47,9 @@ def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder. They may use various prefixes depending on the training framework. """ + if is_state_dict_likely_z_image_kohya_lora(state_dict): + return True + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] # Check for Z-Image transformer keys (S3-DiT architecture) @@ -57,6 +84,7 @@ def lora_model_from_z_image_state_dict( - "transformer." or "base_model.model.transformer." for diffusers PEFT format - "diffusion_model." for some training frameworks - "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder + - "lora_unet__" for Kohya format (underscores instead of dots) Args: state_dict: The LoRA state dict @@ -65,6 +93,10 @@ def lora_model_from_z_image_state_dict( Returns: A ModelPatchRaw containing the LoRA layers """ + # If Kohya format, convert keys first then process normally + if is_state_dict_likely_z_image_kohya_lora(state_dict): + state_dict = _convert_z_image_kohya_state_dict(state_dict) + layers: dict[str, BaseLayerPatch] = {} # Group keys by layer @@ -120,6 +152,45 @@ def lora_model_from_z_image_state_dict( return ModelPatchRaw(layers=layers) +def _convert_z_image_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Converts a Kohya-format Z-Image LoRA state dict to diffusion_model dot-notation. + + Example key conversions: + - lora_unet__layers_0_attention_to_k.lora_down.weight -> diffusion_model.layers.0.attention.to_k.lora_down.weight + - lora_unet__context_refiner_0_feed_forward_w1.alpha -> diffusion_model.context_refiner.0.feed_forward.w1.alpha + - lora_unet__noise_refiner_1_attention_to_v.lora_up.weight -> diffusion_model.noise_refiner.1.attention.to_v.lora_up.weight + """ + converted: Dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(key, str) or not key.startswith("lora_unet__"): + converted[key] = value + continue + + # Split into layer name and param suffix (e.g. "lora_down.weight", "alpha") + layer_name, _, param_suffix = key.partition(".") + + # Strip lora_unet__ prefix + remainder = layer_name[len("lora_unet__") :] + + # Convert Kohya underscore format to dot-notation using the known structure + match = re.match( + r"(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)$", + remainder, + ) + if match: + block, idx, submodule, param = match.groups() + new_layer = f"diffusion_model.{block}.{idx}.{submodule}.{param}" + else: + # Fallback: keep original key for unrecognized patterns + converted[key] = value + continue + + new_key = f"{new_layer}.{param_suffix}" if param_suffix else new_layer + converted[new_key] = value + + return converted + + def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]: """Convert layer dict keys from PEFT format to internal format.""" if "lora_A.weight" in layer_dict: From 18af72c497c2f417d6acbe03917d3f9b48666e23 Mon Sep 17 00:00:00 2001 From: Cocoon-Break <54054995+kuishou68@users.noreply.github.com> Date: Tue, 21 Apr 2026 09:00:19 +0800 Subject: [PATCH 6/7] fix: remove directory entry sizes from directory_size() to report accurate file totals (#9040) * fix: remove dir entry sizes from directory_size() (Closes #9039) * fix: replace unused 'dirs' variable with '_' to resolve ruff F841 linting error --------- Co-authored-by: Lincoln Stein Co-authored-by: Alexander Eichhorn --- invokeai/backend/util/util.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index cc654e4d39..fb8671cec2 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -40,11 +40,9 @@ def directory_size(directory: Path) -> int: Return the aggregate size of all files in a directory (bytes). """ sum = 0 - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for f in files: sum += Path(root, f).stat().st_size - for d in dirs: - sum += Path(root, d).stat().st_size return sum From c83f29362ea11b735dda355d390172a2b9e1c33e Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 21 Apr 2026 03:58:16 +0200 Subject: [PATCH 7/7] fix(flux2-vae): support FLUX.2 small-decoder VAE variant (#9032) Infer encoder and decoder block_out_channels independently from the state dict and rebuild the decoder submodule when its channel widths differ from the encoder, so the asymmetric full_encoder_small_decoder checkpoint from black-forest-labs/FLUX.2-small-decoder loads correctly. Co-authored-by: Lincoln Stein --- .../model_manager/load/model_loaders/flux.py | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index 2de51a8aca..c802154797 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -178,12 +178,43 @@ class Flux2VAELoader(ModelLoader): if is_bfl_format: sd = self._convert_flux2_vae_bfl_to_diffusers(sd) - # FLUX.2 VAE configuration (32 latent channels) - # Based on the official FLUX.2 VAE architecture - # Use default config - AutoencoderKLFlux2 has built-in defaults + # FLUX.2 VAE configuration (32 latent channels). + # The standard FLUX.2 VAE uses block_out_channels=(128,256,512,512) for both + # encoder and decoder. The "small decoder" variant from + # black-forest-labs/FLUX.2-small-decoder keeps the full encoder but uses a + # narrower decoder with channels (96,192,384,384). AutoencoderKLFlux2 only + # exposes a single block_out_channels, so we build the model with the + # encoder's channels and, if the decoder differs, replace just the decoder + # submodule with a matching one before loading the state dict. + encoder_block_out_channels = (128, 256, 512, 512) + decoder_block_out_channels = encoder_block_out_channels + if "encoder.conv_in.weight" in sd and "encoder.conv_norm_out.weight" in sd: + enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0]) + enc_first = int(sd["encoder.conv_in.weight"].shape[0]) + encoder_block_out_channels = (enc_first, enc_first * 2, enc_last, enc_last) + if "decoder.conv_in.weight" in sd and "decoder.conv_norm_out.weight" in sd: + dec_last = int(sd["decoder.conv_in.weight"].shape[0]) + dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0]) + decoder_block_out_channels = (dec_first, dec_first * 2, dec_last, dec_last) + with SilenceWarnings(): with accelerate.init_empty_weights(): - model = AutoencoderKLFlux2() + model = AutoencoderKLFlux2(block_out_channels=encoder_block_out_channels) + if decoder_block_out_channels != encoder_block_out_channels: + # Rebuild the decoder with the smaller channel widths. + from diffusers.models.autoencoders.vae import Decoder + + cfg = model.config + model.decoder = Decoder( + in_channels=cfg.latent_channels, + out_channels=cfg.out_channels, + up_block_types=cfg.up_block_types, + block_out_channels=decoder_block_out_channels, + layers_per_block=cfg.layers_per_block, + norm_num_groups=cfg.norm_num_groups, + act_fn=cfg.act_fn, + mid_block_add_attention=cfg.mid_block_add_attention, + ) # Convert to bfloat16 and load for k in sd.keys():