diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1d99f2cae4..3e7e742934 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1316,9 +1316,11 @@ "zImageQwen3Source": "Qwen3 & VAE Source Model", "zImageQwen3SourcePlaceholder": "Required if VAE/Encoder empty", "flux2KleinVae": "VAE (optional)", - "flux2KleinVaePlaceholder": "From main model", + "flux2KleinVaePlaceholder": "From diffusers model", + "flux2KleinVaeNoModelPlaceholder": "No diffusers model available", "flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)", - "flux2KleinQwen3EncoderPlaceholder": "From main model", + "flux2KleinQwen3EncoderPlaceholder": "From diffusers model", + "flux2KleinQwen3EncoderNoModelPlaceholder": "No diffusers model available", "qwenImageComponentSource": "VAE/Encoder Source (Diffusers)", "qwenImageComponentSourcePlaceholder": "Required for GGUF models", "qwenImageQuantization": "Encoder Quantization", @@ -1623,6 +1625,8 @@ "noFLUXVAEModelSelected": "No VAE model selected for FLUX generation", "noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation", "noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation", + "noFlux2KleinVaeModelSelected": "No VAE selected. Non-diffusers FLUX.2 Klein models require a standalone VAE", + "noFlux2KleinQwen3EncoderModelSelected": "No Qwen3 Encoder selected. Non-diffusers FLUX.2 Klein models require a standalone Qwen3 Encoder", "noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder", "noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model", "noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model", diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index 2b7c0f7d17..1ea7626290 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -35,6 +35,15 @@ type PayloadActionWithId = T extends void } & T >; +/** Fingerprint used to match the same reference image entry after recall when ids are regenerated. */ +/** Empty configs of the same type may collide; the worst case is selecting an equivalent empty entity. */ +const getRefImageRecallMatchKey = (entity: RefImageState): string => { + const { config } = entity; + const imageName = config.image?.original.image.image_name ?? ''; + const modelKey = 'model' in config && config.model ? config.model.key : ''; + return `${config.type}\0${modelKey}\0${imageName}`; +}; + const slice = createSlice({ name: 'refImages', initialState: getInitialRefImagesState(), @@ -54,13 +63,41 @@ const slice = createSlice({ }, refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => { const { entities, replace } = action.payload; - if (replace) { - state.entities = entities; - state.isPanelOpen = false; - state.selectedEntityId = null; - } else { + if (!replace) { state.entities.push(...entities); + return; } + const wasPanelOpen = state.isPanelOpen; + const previousSelectedId = state.selectedEntityId; + let previousEntity: RefImageState | null = null; + if (previousSelectedId !== null) { + previousEntity = state.entities.find((e) => e.id === previousSelectedId) ?? null; + } + state.entities = entities; + if (entities.length === 0) { + state.selectedEntityId = null; + state.isPanelOpen = false; + return; + } + if (!wasPanelOpen) { + state.selectedEntityId = null; + return; + } + const firstEntity = entities[0]; + assert(firstEntity); + if (previousSelectedId === null) { + // Open panel must have a selection; otherwise, fall back to the first entity. + state.selectedEntityId = firstEntity.id; + return; + } + if (previousSelectedId !== null && entities.some((e) => e.id === previousSelectedId)) { + state.selectedEntityId = previousSelectedId; + return; + } + const previousKey = previousEntity ? getRefImageRecallMatchKey(previousEntity) : null; + const matched = + previousKey !== null ? entities.find((e) => getRefImageRecallMatchKey(e) === previousKey) : undefined; + state.selectedEntityId = matched?.id ?? firstEntity.id; }, refImageImageChanged: (state, action: PayloadActionWithId<{ croppableImage: CroppableImageWithDims | null }>) => { const { id, croppableImage } = action.payload; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts new file mode 100644 index 0000000000..7f01becc3d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.test.ts @@ -0,0 +1,370 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('app/logging/logger', () => ({ + logger: () => ({ + debug: vi.fn(), + }), +})); + +let nextId = 0; +vi.mock('features/controlLayers/konva/util', () => ({ + getPrefixedId: (prefix: string) => `${prefix}:${nextId++}`, +})); + +// --- Flux2 Klein model fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-klein-diffusers', + hash: 'flux2-diff-hash', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const flux2GGUFModel = { + key: 'flux2-klein-gguf', + hash: 'flux2-gguf-hash', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +}; + +const kleinVaeModelFixture = { key: 'klein-vae', name: 'Klein VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3EncoderModelFixture = { + key: 'klein-qwen3', + name: 'Qwen3 4B', + base: 'flux2', + type: 'qwen3_encoder', +}; + +const flux2GGUF9BModel = { + key: 'flux2-klein-gguf-9b', + hash: 'flux2-gguf-9b-hash', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +}; + +const diffusersSourceModelFixture = { + key: 'flux2-source-diffusers', + hash: 'flux2-src-hash', + name: 'FLUX.2 Klein 4B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +}; + +const diffusers9BSourceModelFixture = { + key: 'flux2-source-diffusers-9b', + hash: 'flux2-src-9b-hash', + name: 'FLUX.2 Klein 9B Source', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_9b', +}; + +// --- Mutable state --- + +let model: Record = { ...flux2DiffusersModel }; +let kleinVaeModel: Record | null = null; +let kleinQwen3EncoderModel: Record | null = null; +let diffusersModels: Record[] = []; + +vi.mock('features/controlLayers/store/paramsSlice', () => ({ + selectMainModelConfig: vi.fn(() => model), + selectParamsSlice: vi.fn(() => ({ + guidance: 4, + steps: 20, + fluxScheduler: 'euler', + fluxDypePreset: 'off', + fluxDypeScale: 2.0, + fluxDypeExponent: 2.0, + fluxVAE: null, + t5EncoderModel: null, + clipEmbedModel: null, + })), + selectKleinVaeModel: vi.fn(() => kleinVaeModel), + selectKleinQwen3EncoderModel: vi.fn(() => kleinQwen3EncoderModel), +})); + +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ + selectRefImagesSlice: vi.fn(() => ({ + entities: [], + })), +})); + +vi.mock('features/controlLayers/store/selectors', () => ({ + selectCanvasMetadata: vi.fn(() => ({})), + selectCanvasSlice: vi.fn(() => ({})), +})); + +vi.mock('features/controlLayers/store/types', () => ({ + isFlux2ReferenceImageConfig: vi.fn(() => false), + isFluxKontextReferenceImageConfig: vi.fn(() => false), +})); + +vi.mock('features/controlLayers/store/validators', () => ({ + getGlobalReferenceImageWarnings: vi.fn(() => []), +})); + +vi.mock('features/nodes/util/graph/generation/addFlux2KleinLoRAs', () => ({ + addFlux2KleinLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXFill', () => ({ + addFLUXFill: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXLoRAs', () => ({ + addFLUXLoRAs: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addFLUXRedux', () => ({ + addFLUXReduxes: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addImageToImage', () => ({ + addImageToImage: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addInpaint', () => ({ + addInpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addNSFWChecker', () => ({ + addNSFWChecker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addOutpaint', () => ({ + addOutpaint: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addRegions', () => ({ + addRegions: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addTextToImage', () => ({ + addTextToImage: vi.fn(({ l2i }) => l2i), +})); + +vi.mock('features/nodes/util/graph/generation/addWatermarker', () => ({ + addWatermarker: vi.fn((_g, node) => node), +})); + +vi.mock('features/nodes/util/graph/generation/addControlAdapters', () => ({ + addControlLoRA: vi.fn(), + addControlNets: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/generation/addIPAdapters', () => ({ + addIPAdapters: vi.fn(), +})); + +vi.mock('features/nodes/util/graph/graphBuilderUtils', () => ({ + selectCanvasOutputFields: vi.fn(() => ({})), + selectPresetModifiedPrompts: vi.fn(() => ({ + positive: 'a prompt', + negative: '', + })), +})); + +vi.mock('features/ui/store/uiSelectors', () => ({ + selectActiveTab: vi.fn(() => 'generation'), +})); + +vi.mock('services/api/hooks/modelsByType', () => ({ + selectFlux2DiffusersModels: vi.fn(() => diffusersModels), +})); + +vi.mock('services/api/types', async () => { + const actual = await vi.importActual('services/api/types'); + return { + ...actual, + isNonRefinerMainModelConfig: vi.fn(() => true), + }; +}); + +import { buildFLUXGraph } from './buildFLUXGraph'; + +const buildGraphArg = () => ({ + generationMode: 'txt2img' as const, + manager: null, + state: { + system: { + shouldUseNSFWChecker: false, + shouldUseWatermarker: false, + }, + } as never, +}); + +/** Find the flux2_klein_model_loader node in the graph. */ +const getLoaderNode = async () => { + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const loaderEntry = Object.entries(graph.nodes).find(([id]) => id.startsWith('flux2_klein_model_loader:')); + return loaderEntry?.[1] as Record | undefined; +}; + +describe('buildFLUXGraph – FLUX.2 Klein qwen3_source_model', () => { + afterEach(() => { + nextId = 0; + model = { ...flux2DiffusersModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = null; + diffusersModels = []; + }); + + it('does not set qwen3_source_model when main model is diffusers', async () => { + model = { ...flux2DiffusersModel }; + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when main model is GGUF and a diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toEqual({ + key: diffusersSourceModelFixture.key, + hash: diffusersSourceModelFixture.hash, + name: diffusersSourceModelFixture.name, + base: diffusersSourceModelFixture.base, + type: diffusersSourceModelFixture.type, + }); + }); + + it('does not set qwen3_source_model when main model is GGUF but standalone VAE and Qwen3 are both selected', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('does not set qwen3_source_model when main model is GGUF and no diffusers model is available', async () => { + model = { ...flux2GGUFModel }; + diffusersModels = []; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + it('sets qwen3_source_model when only VAE is selected but Qwen3 is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = null; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('sets qwen3_source_model when only Qwen3 is selected but VAE is missing', async () => { + model = { ...flux2GGUFModel }; + kleinVaeModel = null; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.qwen3_source_model).toBeDefined(); + }); + + it('passes standalone vae_model and qwen3_encoder_model when selected', async () => { + model = { ...flux2DiffusersModel }; + kleinVaeModel = kleinVaeModelFixture; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + expect(loader!.vae_model).toEqual(kleinVaeModelFixture); + expect(loader!.qwen3_encoder_model).toEqual(kleinQwen3EncoderModelFixture); + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + + describe('variant matching', () => { + it('selects a variant-matching diffusers model when multiple are available', async () => { + model = { ...flux2GGUF9BModel }; + diffusersModels = [diffusersSourceModelFixture, diffusers9BSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should pick the 9B diffusers model, not the 4B + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusers9BSourceModelFixture.key })); + }); + + it('falls back to any diffusers model for VAE when standalone Qwen3 is selected but no variant match', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = kleinQwen3EncoderModelFixture; + // Only 4B diffusers available, no 9B — but Qwen3 is already provided standalone + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should use the 4B diffusers model just for VAE extraction + expect(loader!.qwen3_source_model).toEqual(expect.objectContaining({ key: diffusersSourceModelFixture.key })); + }); + + it('does not set qwen3_source_model when GGUF 9B with only 4B diffusers available and no standalone Qwen3', async () => { + model = { ...flux2GGUF9BModel }; + kleinQwen3EncoderModel = null; + // Only 4B diffusers available — wrong variant for Qwen3, no standalone Qwen3 selected + diffusersModels = [diffusersSourceModelFixture]; + + const loader = await getLoaderNode(); + expect(loader).toBeDefined(); + // Should NOT use the 4B diffusers since it has the wrong Qwen3 encoder + expect(loader!.qwen3_source_model).toBeUndefined(); + }); + }); + + describe('graph structure', () => { + it('uses flux2_klein_model_loader for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_model_loader:'))).toBe(true); + }); + + it('uses flux2_vae_decode for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_vae_decode:'))).toBe(true); + }); + + it('uses flux2_klein_text_encoder for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeIds = Object.keys(graph.nodes); + expect(nodeIds.some((id) => id.startsWith('flux2_klein_text_encoder:'))).toBe(true); + }); + + it('uses flux2_denoise for flux2 models', async () => { + model = { ...flux2DiffusersModel }; + const { g } = await buildFLUXGraph(buildGraphArg()); + const graph = g.getGraph(); + const nodeTypes = Object.values(graph.nodes).map((n) => n.type); + expect(nodeTypes).toContain('flux2_denoise'); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index ba27e5dbf6..407c921421 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -10,7 +10,8 @@ import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlic import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isFlux2ReferenceImageConfig, isFluxKontextReferenceImageConfig } from 'features/controlLayers/store/types'; import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; -import { zImageField } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { zImageField, zModelIdentifierField } from 'features/nodes/types/common'; import { addFlux2KleinLoRAs } from 'features/nodes/util/graph/generation/addFlux2KleinLoRAs'; import { addFLUXFill } from 'features/nodes/util/graph/generation/addFLUXFill'; import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; @@ -26,8 +27,10 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields } from 'features/nodes/util/graph/graphBuilderUtils'; import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { t } from 'i18next'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -141,7 +144,23 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); + const sourceModel = variantMatch ?? (kleinQwen3EncoderModel ? diffusersModels[0] : undefined); + if (sourceModel) { + qwen3SourceModel = zModelIdentifierField.parse(sourceModel); + } + } + modelLoader = g.addNode({ type: 'flux2_klein_model_loader', id: getPrefixedId('flux2_klein_model_loader'), @@ -149,6 +168,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise { const dispatch = useAppDispatch(); const { t } = useTranslation(); const kleinVaeModel = useAppSelector(selectKleinVaeModel); + const mainModelConfig = useAppSelector(selectMainModelConfig); const [modelConfigs, { isLoading }] = useFlux2VAEModels(); + const [diffusersModels] = useFlux2DiffusersModels(); const _onChange = useCallback( (model: VAEModelConfig | null) => { @@ -42,6 +45,11 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { isLoading, }); + const hasDiffusersSource = mainModelConfig?.format === 'diffusers' || diffusersModels.length > 0; + const placeholder = hasDiffusersSource + ? t('modelManager.flux2KleinVaePlaceholder') + : t('modelManager.flux2KleinVaeNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinVae')} @@ -51,7 +59,7 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinVaePlaceholder')} + placeholder={placeholder} /> ); @@ -59,15 +67,6 @@ const ParamFlux2KleinVaeModelSelect = memo(() => { ParamFlux2KleinVaeModelSelect.displayName = 'ParamFlux2KleinVaeModelSelect'; -/** - * Maps FLUX.2 Klein variants to compatible Qwen3 encoder variants - */ -const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { - klein_4b: 'qwen3_4b', - klein_9b: 'qwen3_8b', - klein_9b_base: 'qwen3_8b', -}; - /** * FLUX.2 Klein Qwen3 Encoder Model Select * Selects a Qwen3 text encoder model for FLUX.2 Klein @@ -79,6 +78,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { const kleinQwen3EncoderModel = useAppSelector(selectKleinQwen3EncoderModel); const mainModelConfig = useAppSelector(selectMainModelConfig); const [allModelConfigs, { isLoading }] = useQwen3EncoderModels(); + const [diffusersModels] = useFlux2DiffusersModels(); // Filter Qwen3 encoders based on the main model's variant const modelConfigs = useMemo(() => { @@ -112,6 +112,20 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { isLoading, }); + // Qwen3 encoder requires a Qwen3-compatible diffusers model (variants that share the same Qwen3 encoder). + const hasMatchingDiffusersSource = + mainModelConfig?.format === 'diffusers' || + diffusersModels.some( + (m) => + 'variant' in m && + mainModelConfig && + 'variant' in mainModelConfig && + isFlux2KleinQwen3Compatible(m.variant, mainModelConfig.variant) + ); + const placeholder = hasMatchingDiffusersSource + ? t('modelManager.flux2KleinQwen3EncoderPlaceholder') + : t('modelManager.flux2KleinQwen3EncoderNoModelPlaceholder'); + return ( {t('modelManager.flux2KleinQwen3Encoder')} @@ -121,7 +135,7 @@ const ParamFlux2KleinQwen3EncoderModelSelect = memo(() => { onChange={onChange} noOptionsMessage={noOptionsMessage} isClearable - placeholder={t('modelManager.flux2KleinQwen3EncoderPlaceholder')} + placeholder={placeholder} /> ); diff --git a/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts new file mode 100644 index 0000000000..b9508a4f82 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/util/flux2Klein.ts @@ -0,0 +1,24 @@ +/** + * Maps a FLUX.2 Klein main-model variant to the Qwen3 encoder variant it uses. + * Multiple Klein variants can share the same Qwen3 variant (e.g. `klein_9b` and + * `klein_9b_base` both use `qwen3_8b`), so two different Klein variants can be + * Qwen3-compatible sources for each other. + */ +export const KLEIN_TO_QWEN3_VARIANT_MAP: Record = { + klein_4b: 'qwen3_4b', + klein_9b: 'qwen3_8b', + klein_9b_base: 'qwen3_8b', +}; + +/** + * Returns true if two Klein variants share the same Qwen3 encoder and can therefore + * be used as a Qwen3 source for each other. + */ +export const isFlux2KleinQwen3Compatible = (variantA: unknown, variantB: unknown): boolean => { + if (typeof variantA !== 'string' || typeof variantB !== 'string') { + return false; + } + const qwen3A = KLEIN_TO_QWEN3_VARIANT_MAP[variantA]; + const qwen3B = KLEIN_TO_QWEN3_VARIANT_MAP[variantB]; + return qwen3A !== undefined && qwen3A === qwen3B; +}; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.test.ts b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts new file mode 100644 index 0000000000..632006050e --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/store/readiness.test.ts @@ -0,0 +1,272 @@ +import { describe, expect, it, vi } from 'vitest'; + +vi.mock('features/dynamicPrompts/util/getShouldProcessPrompt', () => ({ + getShouldProcessPrompt: vi.fn(() => false), +})); + +vi.mock('i18next', () => ({ + default: { + t: (key: string) => key, + }, +})); + +import type { ParamsState, RefImagesState } from 'features/controlLayers/store/types'; +import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; +import type { MainModelConfig } from 'services/api/types'; + +import { getReasonsWhyCannotEnqueueCanvasTab, getReasonsWhyCannotEnqueueGenerateTab } from './readiness'; + +// --- Fixtures --- + +const flux2DiffusersModel = { + key: 'flux2-diff', + hash: 'h', + name: 'FLUX.2 Klein 4B', + base: 'flux2', + type: 'main', + format: 'diffusers', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF4BModel = { + key: 'flux2-gguf-4b', + hash: 'h', + name: 'FLUX.2 Klein 4B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_4b', +} as unknown as MainModelConfig; + +const flux2GGUF9BModel = { + key: 'flux2-gguf-9b', + hash: 'h', + name: 'FLUX.2 Klein 9B GGUF', + base: 'flux2', + type: 'main', + format: 'gguf_quantized', + variant: 'klein_9b', +} as unknown as MainModelConfig; + +const kleinVaeModel = { key: 'vae', name: 'VAE', base: 'flux2', type: 'vae' }; +const kleinQwen3Model = { key: 'qwen3', name: 'Qwen3', base: 'flux2', type: 'qwen3_encoder' }; + +const baseDynamicPrompts: DynamicPromptsState = { + _version: 1, + maxPrompts: 100, + combinatorial: false, + prompts: ['test prompt'], + parsingError: undefined, + isError: false, + isLoading: false, + seedBehaviour: 'PER_PROMPT', +}; + +const baseRefImages: RefImagesState = { + entities: [], + ipAdapters: { entities: [], ids: [] }, +} as unknown as RefImagesState; + +const baseParams = { + positivePrompt: 'test', + kleinVaeModel: null, + kleinQwen3EncoderModel: null, +} as unknown as ParamsState; + +// --- Helpers --- + +const buildGenerateTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const buildCanvasTabArg = (overrides: { + model?: MainModelConfig | null; + kleinVaeModel?: unknown; + kleinQwen3EncoderModel?: unknown; + hasFlux2DiffusersVaeSource?: boolean; + hasFlux2DiffusersQwen3Source?: boolean; +}) => ({ + isConnected: true, + model: overrides.model ?? flux2DiffusersModel, + canvas: { + bbox: { + scaleMethod: 'none', + rect: { width: 1024, height: 1024 }, + scaledSize: { width: 1024, height: 1024 }, + }, + controlLayers: { entities: [] }, + regionalGuidance: { entities: [] }, + rasterLayers: { entities: [] }, + inpaintMasks: { entities: [] }, + }, + params: { + ...baseParams, + kleinVaeModel: overrides.kleinVaeModel ?? null, + kleinQwen3EncoderModel: overrides.kleinQwen3EncoderModel ?? null, + } as unknown as ParamsState, + refImages: baseRefImages, + loras: [], + dynamicPrompts: baseDynamicPrompts, + canvasIsFiltering: false, + canvasIsTransforming: false, + canvasIsRasterizing: false, + canvasIsCompositing: false, + canvasIsSelectingObject: false, + hasFlux2DiffusersVaeSource: overrides.hasFlux2DiffusersVaeSource ?? false, + hasFlux2DiffusersQwen3Source: overrides.hasFlux2DiffusersQwen3Source ?? false, +}); + +const hasFlux2VaeReason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinVaeModelSelected')); + +const hasFlux2Qwen3Reason = (reasons: { content: string }[]) => + reasons.some((r) => r.content.includes('noFlux2KleinQwen3EncoderModelSelected')); + +// --- Tests --- + +describe('FLUX.2 Klein readiness checks – generate tab', () => { + it('no errors when main model is diffusers (VAE/Qwen3 extracted from it)', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2DiffusersModel })); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no diffusers source and no standalone models', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab(buildGenerateTabArg({ model: flux2GGUF4BModel })); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for Qwen3 when GGUF model with standalone VAE but no Qwen3 and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinVaeModel: kleinVaeModel }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('errors only for VAE when GGUF model with standalone Qwen3 but no VAE and no diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ model: flux2GGUF4BModel, kleinQwen3EncoderModel: kleinQwen3Model }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B model with only a 4B diffusers source (variant mismatch)', () => { + // User has Klein 9B GGUF selected, only a 4B diffusers model installed. + // VAE is shared across variants so it's ok. Qwen3 encoder differs, so it's not ok. + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF 9B model with variant-matching diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueGenerateTab( + buildGenerateTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); +}); + +describe('FLUX.2 Klein readiness checks – canvas tab', () => { + it('no errors when main model is diffusers', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2DiffusersModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('no errors when GGUF model with both VAE and Qwen3 diffusers sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: true, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('errors for both VAE and Qwen3 when GGUF model with no sources', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab(buildCanvasTabArg({ model: flux2GGUF4BModel }) as never); + expect(hasFlux2VaeReason(reasons)).toBe(true); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); + + it('no errors when GGUF model with both standalone VAE and Qwen3', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF4BModel, + kleinVaeModel: kleinVaeModel, + kleinQwen3EncoderModel: kleinQwen3Model, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(false); + }); + + it('VAE ok but Qwen3 errors when GGUF 9B with variant-mismatched diffusers source', () => { + const reasons = getReasonsWhyCannotEnqueueCanvasTab( + buildCanvasTabArg({ + model: flux2GGUF9BModel, + hasFlux2DiffusersVaeSource: true, + hasFlux2DiffusersQwen3Source: false, + }) as never + ); + expect(hasFlux2VaeReason(reasons)).toBe(false); + expect(hasFlux2Qwen3Reason(reasons)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 67dfe3141c..5802a2aed5 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -33,12 +33,14 @@ import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/ import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue'; import type { UpscaleState } from 'features/parameters/store/upscaleSlice'; import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice'; +import { isFlux2KleinQwen3Compatible } from 'features/parameters/util/flux2Klein'; import { getGridSize } from 'features/parameters/util/optimalDimension'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import type { TabName } from 'features/ui/store/uiTypes'; import i18n from 'i18next'; import { atom, computed } from 'nanostores'; import { useEffect } from 'react'; +import { selectFlux2DiffusersModels } from 'services/api/hooks/modelsByType'; import type { MainOrExternalModelConfig } from 'services/api/types'; import { isExternalApiModelConfig } from 'services/api/types'; import { $isConnected } from 'services/events/stores'; @@ -109,6 +111,12 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { } = arg; if (tab === 'generate') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueGenerateTab({ isConnected, model, @@ -116,10 +124,18 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { refImages, dynamicPrompts, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'canvas') { const model = selectMainModelConfig(store.getState()); + const flux2DiffusersModels = selectFlux2DiffusersModels(store.getState()); + const hasFlux2DiffusersVaeSource = flux2DiffusersModels.length > 0; + const modelVariant = model && 'variant' in model ? model.variant : undefined; + const hasFlux2DiffusersQwen3Source = flux2DiffusersModels.some( + (m) => 'variant' in m && isFlux2KleinQwen3Compatible(m.variant, modelVariant) + ); const reasons = await getReasonsWhyCannotEnqueueCanvasTab({ isConnected, model, @@ -133,6 +149,8 @@ const debouncedUpdateReasons = debounce(async (arg: UpdateReasonsArg) => { canvasIsCompositing, canvasIsSelectingObject, loras, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'workflows') { @@ -220,15 +238,26 @@ export const useReadinessWatcher = () => { const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') }); -const getReasonsWhyCannotEnqueueGenerateTab = (arg: { +export const getReasonsWhyCannotEnqueueGenerateTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; params: ParamsState; refImages: RefImagesState; loras: LoRA[]; dynamicPrompts: DynamicPromptsState; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { - const { isConnected, model, params, refImages, loras, dynamicPrompts } = arg; + const { + isConnected, + model, + params, + refImages, + loras, + dynamicPrompts, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, + } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -260,7 +289,17 @@ const getReasonsWhyCannotEnqueueGenerateTab = (arg: { } } - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + if (model?.base === 'flux2' && model.format !== 'diffusers') { + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } if (model?.base === 'qwen-image' && model.format === 'gguf_quantized') { if (!params.qwenImageComponentSource) { @@ -452,7 +491,7 @@ const getReasonsWhyCannotEnqueueUpscaleTab = (arg: { return reasons; }; -const getReasonsWhyCannotEnqueueCanvasTab = (arg: { +export const getReasonsWhyCannotEnqueueCanvasTab = (arg: { isConnected: boolean; model: MainOrExternalModelConfig | null | undefined; canvas: CanvasState; @@ -465,6 +504,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing: boolean; canvasIsCompositing: boolean; canvasIsSelectingObject: boolean; + hasFlux2DiffusersVaeSource: boolean; + hasFlux2DiffusersQwen3Source: boolean; }) => { const { isConnected, @@ -479,6 +520,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { canvasIsRasterizing, canvasIsCompositing, canvasIsSelectingObject, + hasFlux2DiffusersVaeSource, + hasFlux2DiffusersQwen3Source, } = arg; const { positivePrompt } = params; const reasons: Reason[] = []; @@ -571,7 +614,17 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { } if (model?.base === 'flux2') { - // FLUX.2 (Klein) extracts Qwen3 encoder and VAE from main model - no separate selections needed + // Non-diffusers FLUX.2 Klein models require standalone VAE and Qwen3 Encoder + // unless a diffusers flux2 model is available to extract them from. + // VAE is shared across variants, but Qwen3 encoder requires a variant-matching diffusers model. + if (model.format !== 'diffusers') { + if (!params.kleinVaeModel && !hasFlux2DiffusersVaeSource) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinVaeModelSelected') }); + } + if (!params.kleinQwen3EncoderModel && !hasFlux2DiffusersQwen3Source) { + reasons.push({ content: i18n.t('parameters.invoke.noFlux2KleinQwen3EncoderModelSelected') }); + } + } const { bbox } = canvas; const gridSize = getGridSize('flux'); // FLUX.2 uses same grid size as FLUX.1 diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 55746e5294..2496c06ed0 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -18,6 +18,7 @@ import { isControlNetModelConfig, isExternalApiModelConfig, isFlux1VAEModelConfig, + isFlux2DiffusersMainModelConfig, isFlux2VAEModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig, @@ -101,6 +102,7 @@ export const useFlux2VAEModels = () => buildModelsHook(isFlux2VAEModelConfig)(); export const useAnimaVAEModels = () => buildModelsHook(isAnimaVAEModelConfig)(); export const useAnimaQwen3EncoderModels = () => buildModelsHook(isAnimaQwen3EncoderModelConfig)(); export const useZImageDiffusersModels = () => buildModelsHook(isZImageDiffusersMainModelConfig)(); +export const useFlux2DiffusersModels = () => buildModelsHook(isFlux2DiffusersMainModelConfig)(); export const useQwenImageDiffusersModels = () => buildModelsHook(isQwenImageDiffusersMainModelConfig)(); export const useQwen3EncoderModels = () => buildModelsHook(isQwen3EncoderModelConfig)(); export const useGlobalReferenceImageModels = buildModelsHook( @@ -140,6 +142,7 @@ export const selectAnimaQwen3EncoderModels = buildModelsSelector(isAnimaQwen3Enc export const selectQwen3EncoderModels = buildModelsSelector(isQwen3EncoderModelConfig); export const selectQwenImageDiffusersModels = buildModelsSelector(isQwenImageDiffusersMainModelConfig); export const selectZImageDiffusersModels = buildModelsSelector(isZImageDiffusersMainModelConfig); +export const selectFlux2DiffusersModels = buildModelsSelector(isFlux2DiffusersMainModelConfig); export const selectFluxVAEModels = buildModelsSelector(isFluxVAEModelConfig); export const selectAnimaVAEModels = buildModelsSelector(isAnimaVAEModelConfig); export const selectT5EncoderModels = buildModelsSelector(isT5EncoderModelConfigOrSubmodel); diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 3624d7ef6a..9deefada23 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -457,6 +457,10 @@ export const isZImageDiffusersMainModelConfig = (config: AnyModelConfig): config return config.type === 'main' && config.base === 'z-image' && config.format === 'diffusers'; }; +export const isFlux2DiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base === 'flux2' && config.format === 'diffusers'; +}; + export const isQwenImageDiffusersMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'qwen-image' && config.format === 'diffusers'; };