feat(ui): FLUX linear - add VAE as required model field rather than allowing default

This commit is contained in:
Mary Hipp
2024-09-11 15:25:38 -04:00
committed by psychedelicious
parent ffbf4aba1f
commit 573c7d2088
7 changed files with 81 additions and 4 deletions

View File

@@ -931,6 +931,7 @@
"incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
"canvasManagerNotLoaded": "Canvas Manager not loaded",
"canvasIsFiltering": "Canvas is filtering",

View File

@@ -10,7 +10,7 @@ import {
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { clipEmbedModelSelected, modelChanged, refinerModelChanged, t5EncoderModelSelected, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { clipEmbedModelSelected, fluxVAESelected, modelChanged, refinerModelChanged, t5EncoderModelSelected, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
@@ -23,8 +23,10 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonFluxVAEModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
@@ -54,6 +56,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleIPAdapterModels(models, state, dispatch, log);
handleT5EncoderModels(models, state, dispatch, log)
handleCLIPEmbedModels(models, state, dispatch, log)
handleFLUXVAEModels(models, state, dispatch, log)
},
});
};
@@ -135,7 +138,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const vaeModels = models.filter(isVAEModelConfig);
const vaeModels = models.filter(isNonFluxVAEModelConfig);
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
@@ -255,3 +258,17 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, _log) => {
dispatch(clipEmbedModelSelected(firstModel));
}
};
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, _log) => {
const { fluxVAE: currentFLUXVAEModel } = state.params;
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
const firstModel = fluxVAEModels[0] || null;
const isCurrentFLUXVAEModelAvailable = currentFLUXVAEModel
? fluxVAEModels.some((m) => m.key === currentFLUXVAEModel.key)
: false;
if (!isCurrentFLUXVAEModelAvailable) {
dispatch(fluxVAESelected(firstModel));
}
};

View File

@@ -155,6 +155,9 @@ const createSelector = (
if (!params.clipEmbedModel) {
reasons.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
}
if (!params.fluxVAE) {
reasons.push({ content: i18n.t('parameters.invoke.noFLUXVAEModelSelected') });
}
}
canvas.controlLayers.entities

View File

@@ -48,6 +48,7 @@ export type ParamsState = {
model: ParameterModel | null;
vae: ParameterVAEModel | null;
vaePrecision: ParameterPrecision;
fluxVAE: ParameterVAEModel | null;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
clipSkip: number;
@@ -89,6 +90,7 @@ const initialState: ParamsState = {
steps: 50,
model: null,
vae: null,
fluxVAE: null,
vaePrecision: 'fp32',
seamlessXAxis: false,
seamlessYAxis: false,
@@ -173,6 +175,9 @@ export const paramsSlice = createSlice({
// null is a valid VAE!
state.vae = action.payload;
},
fluxVAESelected: (state, action: PayloadAction<ParameterVAEModel | null>) => {
state.fluxVAE = action.payload;
},
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
@@ -272,6 +277,7 @@ export const {
setSeamlessYAxis,
setShouldRandomizeSeed,
vaeSelected,
fluxVAESelected,
vaePrecisionChanged,
t5EncoderModelSelected,
clipEmbedModelSelected,
@@ -315,6 +321,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);
export const selectVAE = createParamsSelector((params) => params.vae);
export const selectFLUXVAE = createParamsSelector((params) => params.fluxVAE);
export const selectVAEKey = createParamsSelector((params) => params.vae?.key);
export const selectT5EncoderModel = createParamsSelector((params) => params.t5EncoderModel);
export const selectCLIPEmbedModel = createParamsSelector((params) => params.clipEmbedModel);

View File

@@ -0,0 +1,44 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fluxVAESelected, selectFLUXVAE } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
const ParamFLUXVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const vae = useAppSelector(selectFLUXVAE);
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
if (vae) {
dispatch(fluxVAESelected(zModelIdentifierField.parse(vae)));
}
},
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: vae,
isLoading,
});
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
<InformationalPopover feature="paramVAE">
<FormLabel m={0}>{t('modelManager.vae')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
</FormControl>
);
};
export default memo(ParamFLUXVAEModelSelect);

View File

@@ -12,6 +12,7 @@ import { ParamSeedNumberInput } from 'features/parameters/components/Seed/ParamS
import { ParamSeedRandomize } from 'features/parameters/components/Seed/ParamSeedRandomize';
import { ParamSeedShuffle } from 'features/parameters/components/Seed/ParamSeedShuffle';
import ParamVAEModelSelect from 'features/parameters/components/VAEModel/ParamVAEModelSelect';
import ParamFLUXVAEModelSelect from 'features/parameters/components/VAEModel/ParamFLUXVAEModelSelect';
import ParamVAEPrecision from 'features/parameters/components/VAEModel/ParamVAEPrecision';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -86,7 +87,7 @@ export const AdvancedSettingsAccordion = memo(() => {
<StandaloneAccordion label={t('accordions.advanced.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
<Flex gap={4} w="full">
<ParamVAEModelSelect />
{isFLUX ? <ParamFLUXVAEModelSelect /> : <ParamVAEModelSelect />}
{!isFLUX && <ParamVAEPrecision />}
</Flex>
{activeTabName === 'upscaling' ? (
@@ -112,7 +113,7 @@ export const AdvancedSettingsAccordion = memo(() => {
</>
)}
{isFLUX && (
<FormControlGroup formLabelProps={formLabelProps}>
<FormControlGroup>
<ParamT5EncoderModelSelect />
<ParamCLIPEmbedModelSelect />
</FormControlGroup>

View File

@@ -83,6 +83,10 @@ export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConf
return config.type === 'vae';
};
export const isNonFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base !== 'flux';
};
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base === 'flux';
};