mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-05 13:04:59 -05:00
feat(ui): new fields and param layout for FLUX settings: guidance, t5 encoder, CLIP embed
This commit is contained in:
committed by
psychedelicious
parent
8916036ed3
commit
ffbf4aba1f
@@ -907,6 +907,7 @@
|
||||
"downloadImage": "Download Image",
|
||||
"general": "General",
|
||||
"globalSettings": "Global Settings",
|
||||
"guidance": "Guidance",
|
||||
"height": "Height",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"images": "Images",
|
||||
@@ -929,6 +930,8 @@
|
||||
"noModelForControlAdapter": "Control Adapter #{{number}} has no model selected.",
|
||||
"incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
|
||||
"noModelSelected": "No model selected",
|
||||
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
|
||||
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
|
||||
"canvasManagerNotLoaded": "Canvas Manager not loaded",
|
||||
"canvasIsFiltering": "Canvas is filtering",
|
||||
"canvasIsTransforming": "Canvas is transforming",
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
rgIPAdapterModelChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { clipEmbedModelSelected, modelChanged, refinerModelChanged, t5EncoderModelSelected, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
|
||||
@@ -21,12 +21,14 @@ import type { Logger } from 'roarr';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT5EncoderModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
@@ -50,6 +52,8 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
handleSpandrelImageToImageModels(models, state, dispatch, log);
|
||||
handleIPAdapterModels(models, state, dispatch, log);
|
||||
handleT5EncoderModels(models, state, dispatch, log)
|
||||
handleCLIPEmbedModels(models, state, dispatch, log)
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -223,3 +227,31 @@ const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch,
|
||||
dispatch(postProcessingModelChanged(firstModel));
|
||||
}
|
||||
};
|
||||
|
||||
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const { t5EncoderModel: currentT5EncoderModel } = state.params;
|
||||
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
|
||||
const firstModel = t5EncoderModels[0] || null;
|
||||
|
||||
const isCurrentT5EncoderModelAvailable = currentT5EncoderModel
|
||||
? t5EncoderModels.some((m) => m.key === currentT5EncoderModel.key)
|
||||
: false;
|
||||
|
||||
if (!isCurrentT5EncoderModelAvailable) {
|
||||
dispatch(t5EncoderModelSelected(firstModel));
|
||||
}
|
||||
};
|
||||
|
||||
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const { clipEmbedModel: currentCLIPEmbedModel } = state.params;
|
||||
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
|
||||
const firstModel = CLIPEmbedModels[0] || null;
|
||||
|
||||
const isCurrentCLIPEmbedModelAvailable = currentCLIPEmbedModel
|
||||
? CLIPEmbedModels.some((m) => m.key === currentCLIPEmbedModel.key)
|
||||
: false;
|
||||
|
||||
if (!isCurrentCLIPEmbedModelAvailable) {
|
||||
dispatch(clipEmbedModelSelected(firstModel));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -114,6 +114,9 @@ export type AppConfig = {
|
||||
weight: NumericalParameterConfig;
|
||||
};
|
||||
};
|
||||
flux: {
|
||||
guidance: NumericalParameterConfig
|
||||
}
|
||||
};
|
||||
|
||||
export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;
|
||||
|
||||
@@ -147,6 +147,16 @@ const createSelector = (
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
|
||||
}
|
||||
|
||||
if (model?.base === 'flux') {
|
||||
console.log({ params })
|
||||
if (!params.t5EncoderModel) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') });
|
||||
}
|
||||
if (!params.clipEmbedModel) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
|
||||
}
|
||||
}
|
||||
|
||||
canvas.controlLayers.entities
|
||||
.filter((controlLayer) => controlLayer.isEnabled)
|
||||
.forEach((controlLayer, i) => {
|
||||
|
||||
@@ -20,6 +20,9 @@ import type {
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterVAEModel,
|
||||
ParameterGuidance,
|
||||
ParameterT5EncoderModel,
|
||||
ParameterCLIPEmbedModel
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { clamp } from 'lodash-es';
|
||||
|
||||
@@ -35,6 +38,7 @@ export type ParamsState = {
|
||||
infillColorValue: RgbaColor;
|
||||
cfgScale: ParameterCFGScale;
|
||||
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||
guidance: ParameterGuidance;
|
||||
img2imgStrength: ParameterStrength;
|
||||
iterations: number;
|
||||
scheduler: ParameterScheduler;
|
||||
@@ -60,6 +64,8 @@ export type ParamsState = {
|
||||
refinerPositiveAestheticScore: number;
|
||||
refinerNegativeAestheticScore: number;
|
||||
refinerStart: number;
|
||||
t5EncoderModel: ParameterT5EncoderModel | null,
|
||||
clipEmbedModel: ParameterCLIPEmbedModel | null
|
||||
};
|
||||
|
||||
const initialState: ParamsState = {
|
||||
@@ -74,6 +80,7 @@ const initialState: ParamsState = {
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
cfgScale: 7.5,
|
||||
cfgRescaleMultiplier: 0,
|
||||
guidance: 4,
|
||||
img2imgStrength: 0.75,
|
||||
iterations: 1,
|
||||
scheduler: 'euler',
|
||||
@@ -99,6 +106,8 @@ const initialState: ParamsState = {
|
||||
refinerPositiveAestheticScore: 6,
|
||||
refinerNegativeAestheticScore: 2.5,
|
||||
refinerStart: 0.8,
|
||||
t5EncoderModel: null,
|
||||
clipEmbedModel: null
|
||||
};
|
||||
|
||||
export const paramsSlice = createSlice({
|
||||
@@ -114,6 +123,9 @@ export const paramsSlice = createSlice({
|
||||
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
|
||||
state.cfgScale = action.payload;
|
||||
},
|
||||
setGuidance: (state, action: PayloadAction<ParameterGuidance>) => {
|
||||
state.guidance = action.payload;
|
||||
},
|
||||
setCfgRescaleMultiplier: (state, action: PayloadAction<ParameterCFGRescaleMultiplier>) => {
|
||||
state.cfgRescaleMultiplier = action.payload;
|
||||
},
|
||||
@@ -161,6 +173,12 @@ export const paramsSlice = createSlice({
|
||||
// null is a valid VAE!
|
||||
state.vae = action.payload;
|
||||
},
|
||||
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
|
||||
state.t5EncoderModel = action.payload;
|
||||
},
|
||||
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
|
||||
state.clipEmbedModel = action.payload;
|
||||
},
|
||||
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
|
||||
state.vaePrecision = action.payload;
|
||||
},
|
||||
@@ -246,6 +264,7 @@ export const {
|
||||
setSteps,
|
||||
setCfgScale,
|
||||
setCfgRescaleMultiplier,
|
||||
setGuidance,
|
||||
setScheduler,
|
||||
setSeed,
|
||||
setImg2imgStrength,
|
||||
@@ -254,6 +273,8 @@ export const {
|
||||
setShouldRandomizeSeed,
|
||||
vaeSelected,
|
||||
vaePrecisionChanged,
|
||||
t5EncoderModelSelected,
|
||||
clipEmbedModelSelected,
|
||||
setClipSkip,
|
||||
shouldUseCpuNoiseChanged,
|
||||
positivePromptChanged,
|
||||
@@ -289,11 +310,16 @@ export const createParamsSelector = <T>(selector: Selector<ParamsState, T>) =>
|
||||
|
||||
export const selectBase = createParamsSelector((params) => params.model?.base);
|
||||
export const selectIsSDXL = createParamsSelector((params) => params.model?.base === 'sdxl');
|
||||
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
|
||||
|
||||
export const selectModel = createParamsSelector((params) => params.model);
|
||||
export const selectModelKey = createParamsSelector((params) => params.model?.key);
|
||||
export const selectVAE = createParamsSelector((params) => params.vae);
|
||||
export const selectVAEKey = createParamsSelector((params) => params.vae?.key);
|
||||
export const selectT5EncoderModel = createParamsSelector((params) => params.t5EncoderModel);
|
||||
export const selectCLIPEmbedModel = createParamsSelector((params) => params.clipEmbedModel);
|
||||
export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
|
||||
export const selectGuidance = createParamsSelector((params) => params.guidance);
|
||||
export const selectSteps = createParamsSelector((params) => params.steps);
|
||||
export const selectCFGRescaleMultiplier = createParamsSelector((params) => params.cfgRescaleMultiplier);
|
||||
export const selectCLIPSKip = createParamsSelector((params) => params.clipSkip);
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
setScheduler,
|
||||
setSeed,
|
||||
setSteps,
|
||||
t5EncoderModelSelected,
|
||||
vaeSelected,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
@@ -44,6 +45,7 @@ import type {
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterT5EncoderModel,
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
@@ -154,6 +156,10 @@ const recallVAE: MetadataRecallFunc<ParameterVAEModel | null | undefined> = (vae
|
||||
getStore().dispatch(vaeSelected(vaeModel));
|
||||
};
|
||||
|
||||
const recallT5Encoder: MetadataRecallFunc<ParameterT5EncoderModel> = (t5EncoderModel) => {
|
||||
getStore().dispatch(t5EncoderModelSelected(t5EncoderModel));
|
||||
};
|
||||
|
||||
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
||||
getStore().dispatch(loraRecalled({ lora }));
|
||||
};
|
||||
@@ -196,4 +202,5 @@ export const recallers = {
|
||||
vae: recallVAE,
|
||||
lora: recallLoRA,
|
||||
loras: recallAllLoRAs,
|
||||
t5EncoderModel: recallT5Encoder
|
||||
} as const;
|
||||
|
||||
@@ -7,7 +7,7 @@ import type {
|
||||
T2IAdapterConfigMetadata,
|
||||
} from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ParameterSDXLRefinerModel, ParameterT5EncoderModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { BaseModelType } from 'services/api/types';
|
||||
|
||||
/**
|
||||
@@ -40,6 +40,13 @@ const validateVAEModel: MetadataValidateFunc<ParameterVAEModel> = (vaeModel) =>
|
||||
});
|
||||
};
|
||||
|
||||
const validateT5EncoderModel: MetadataValidateFunc<ParameterT5EncoderModel> = (t5EncoderModel) => {
|
||||
validateBaseCompatibility('flux', 'T5 Encoder incompatible with currently-selected model');
|
||||
return new Promise((resolve) => {
|
||||
resolve(t5EncoderModel);
|
||||
});
|
||||
};
|
||||
|
||||
const validateLoRA: MetadataValidateFunc<LoRA> = (lora) => {
|
||||
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
|
||||
return new Promise((resolve) => {
|
||||
@@ -131,6 +138,7 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
|
||||
export const validators = {
|
||||
refinerModel: validateRefinerModel,
|
||||
vaeModel: validateVAEModel,
|
||||
t5EncoderModel: validateT5EncoderModel,
|
||||
lora: validateLoRA,
|
||||
loras: validateLoRAs,
|
||||
controlNet: validateControlNet,
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useClipEmbedModels,
|
||||
useCLIPEmbedModels,
|
||||
useControlNetModels,
|
||||
useEmbeddingModels,
|
||||
useIPAdapterModels,
|
||||
@@ -85,7 +85,7 @@ const ModelList = () => {
|
||||
[t5EncoderModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useClipEmbedModels();
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels();
|
||||
const filteredClipEmbedModels = useMemo(
|
||||
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
|
||||
[clipEmbedModels, searchTerm, filteredModelType]
|
||||
|
||||
@@ -5,8 +5,8 @@ import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ClipEmbedModelConfig } from 'services/api/types';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -17,9 +17,9 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useClipEmbedModels();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
(value: ClipEmbedModelConfig | null) => {
|
||||
(value: CLIPEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { clipEmbedModelSelected, selectCLIPEmbedModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
import { useModelCombobox } from '../../../../common/hooks/useModelCombobox';
|
||||
|
||||
const ParamCLIPEmbedModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const clipEmbedModel = useAppSelector(selectCLIPEmbedModel);
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(clipEmbedModel: CLIPEmbedModelConfig | null) => {
|
||||
if (clipEmbedModel) {
|
||||
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: clipEmbedModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
|
||||
<FormLabel m={0}>{t('modelManager.clipEmbed')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamCLIPEmbedModelSelect);
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectT5EncoderModel, t5EncoderModelSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
import { useModelCombobox } from '../../../../common/hooks/useModelCombobox';
|
||||
|
||||
const ParamT5EncoderModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const t5EncoderModel = useAppSelector(selectT5EncoderModel);
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
if (t5EncoderModel) {
|
||||
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: t5EncoderModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
|
||||
<FormLabel m={0}>{t('modelManager.t5Encoder')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamT5EncoderModelSelect);
|
||||
@@ -0,0 +1,49 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectGuidance, setGuidance } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectGuidanceConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamGuidance = () => {
|
||||
const guidance = useAppSelector(selectGuidance);
|
||||
const config = useAppSelector(selectGuidanceConfig);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(
|
||||
() => [
|
||||
config.sliderMin,
|
||||
Math.floor(config.sliderMax - (config.sliderMax - config.sliderMin) / 2),
|
||||
config.sliderMax,
|
||||
],
|
||||
[config.sliderMax, config.sliderMin]
|
||||
);
|
||||
const onChange = useCallback((v: number) => dispatch(setGuidance(v)), [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.guidance')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={guidance}
|
||||
defaultValue={config.initial}
|
||||
min={config.sliderMin}
|
||||
max={config.sliderMax}
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={guidance}
|
||||
defaultValue={config.initial}
|
||||
min={config.numberInputMin}
|
||||
max={config.numberInputMax}
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamGuidance);
|
||||
@@ -7,14 +7,14 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSDMainModels } from 'services/api/hooks/modelsByType';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModel = useAppSelector(selectModel);
|
||||
const [modelConfigs, { isLoading }] = useSDMainModels();
|
||||
const [modelConfigs, { isLoading }] = useMainModels();
|
||||
const tooltipLabel = useMemo(() => {
|
||||
if (!modelConfigs.length || !selectedModel) {
|
||||
return;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { createParamsSelector } from 'features/controlLayers/store/paramsSlice';
|
||||
import { createParamsSelector, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
|
||||
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
|
||||
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
|
||||
@@ -15,11 +15,12 @@ const selectWithStylePrompts = createParamsSelector((params) => {
|
||||
|
||||
export const Prompts = memo(() => {
|
||||
const withStylePrompts = useAppSelector(selectWithStylePrompts);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<ParamPositivePrompt />
|
||||
{withStylePrompts && <ParamSDXLPositiveStylePrompt />}
|
||||
<ParamNegativePrompt />
|
||||
{!isFLUX && <ParamNegativePrompt />}
|
||||
{withStylePrompts && <ParamSDXLNegativeStylePrompt />}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -56,6 +56,13 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
||||
zParameterCFGScale.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Guidance parameter
|
||||
const zParameterGuidance = z.number().min(1);
|
||||
export type ParameterGuidance = z.infer<typeof zParameterGuidance>;
|
||||
export const isParameterGuidance = (val: unknown): val is ParameterGuidance =>
|
||||
zParameterGuidance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region CFG Rescale Multiplier
|
||||
const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||
export type ParameterCFGRescaleMultiplier = z.infer<typeof zParameterCFGRescaleMultiplier>;
|
||||
@@ -106,6 +113,16 @@ export const zParameterVAEModel = zModelIdentifierField;
|
||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||
// #endregion
|
||||
|
||||
// #region T5Encoder Model
|
||||
export const zParameterT5EncoderModel = zModelIdentifierField;
|
||||
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;
|
||||
// #endregion
|
||||
|
||||
// #region CLIP embed Model
|
||||
export const zParameterCLIPEmbedModel = zModelIdentifierField;
|
||||
export type ParameterCLIPEmbedModel = z.infer<typeof zParameterCLIPEmbedModel>;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA Model
|
||||
const zParameterLoRAModel = zModelIdentifierField;
|
||||
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-libra
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsFLUX, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
|
||||
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
|
||||
import ParamClipSkip from 'features/parameters/components/Advanced/ParamClipSkip';
|
||||
import ParamSeamlessXAxis from 'features/parameters/components/Seamless/ParamSeamlessXAxis';
|
||||
@@ -18,6 +18,8 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import ParamT5EncoderModelSelect from '../../../parameters/components/Advanced/ParamT5EncoderModelSelect';
|
||||
import ParamCLIPEmbedModelSelect from '../../../parameters/components/Advanced/ParamCLIPEmbedModelSelect';
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
minW: '9.2rem',
|
||||
@@ -31,32 +33,44 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
const vaeKey = useAppSelector(selectVAEKey);
|
||||
const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectParamsSlice, (params) => {
|
||||
createMemoizedSelector([selectParamsSlice, selectIsFLUX], (params, isFLUX) => {
|
||||
const badges: (string | number)[] = [];
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (params.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${params.vaePrecision}`;
|
||||
if (isFLUX) {
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (params.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${params.vaePrecision}`;
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
}
|
||||
} else {
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (params.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${params.vaePrecision}`;
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
} else if (params.vaePrecision === 'fp16') {
|
||||
badges.push(`VAE ${params.vaePrecision}`);
|
||||
}
|
||||
if (params.clipSkip) {
|
||||
badges.push(`Skip ${params.clipSkip}`);
|
||||
}
|
||||
if (params.cfgRescaleMultiplier) {
|
||||
badges.push(`Rescale ${params.cfgRescaleMultiplier}`);
|
||||
}
|
||||
if (params.seamlessXAxis || params.seamlessYAxis) {
|
||||
badges.push('seamless');
|
||||
}
|
||||
if (activeTabName === 'upscaling' && !params.shouldRandomizeSeed) {
|
||||
badges.push('Manual Seed');
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
} else if (params.vaePrecision === 'fp16') {
|
||||
badges.push(`VAE ${params.vaePrecision}`);
|
||||
}
|
||||
if (params.clipSkip) {
|
||||
badges.push(`Skip ${params.clipSkip}`);
|
||||
}
|
||||
if (params.cfgRescaleMultiplier) {
|
||||
badges.push(`Rescale ${params.cfgRescaleMultiplier}`);
|
||||
}
|
||||
if (params.seamlessXAxis || params.seamlessYAxis) {
|
||||
badges.push('seamless');
|
||||
}
|
||||
if (activeTabName === 'upscaling' && !params.shouldRandomizeSeed) {
|
||||
badges.push('Manual Seed');
|
||||
}
|
||||
|
||||
return badges;
|
||||
}),
|
||||
[vaeConfig, activeTabName]
|
||||
@@ -73,27 +87,36 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
|
||||
<Flex gap={4} w="full">
|
||||
<ParamVAEModelSelect />
|
||||
<ParamVAEPrecision />
|
||||
{!isFLUX && <ParamVAEPrecision />}
|
||||
</Flex>
|
||||
{activeTabName === 'upscaling' && (
|
||||
{activeTabName === 'upscaling' ? (
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamSeedNumberInput />
|
||||
<ParamSeedShuffle />
|
||||
<ParamSeedRandomize />
|
||||
</Flex>
|
||||
)}
|
||||
{activeTabName !== 'upscaling' && (
|
||||
) : (
|
||||
<>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamClipSkip />
|
||||
<ParamCFGRescaleMultiplier />
|
||||
</FormControlGroup>
|
||||
<Flex gap={4} w="full">
|
||||
<FormControlGroup formLabelProps={formLabelProps2}>
|
||||
<ParamSeamlessXAxis />
|
||||
<ParamSeamlessYAxis />
|
||||
{!isFLUX && (
|
||||
<>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamClipSkip />
|
||||
<ParamCFGRescaleMultiplier />
|
||||
</FormControlGroup>
|
||||
<Flex gap={4} w="full">
|
||||
<FormControlGroup formLabelProps={formLabelProps2}>
|
||||
<ParamSeamlessXAxis />
|
||||
<ParamSeamlessYAxis />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{isFLUX && (
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamT5EncoderModelSelect />
|
||||
<ParamCLIPEmbedModelSelect />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
@@ -18,6 +18,8 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
|
||||
import ParamGuidance from '../../../parameters/components/Core/ParamGuidance';
|
||||
import { selectIsFLUX } from '../../../controlLayers/store/paramsSlice';
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
minW: '4rem',
|
||||
@@ -27,6 +29,7 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoRAsSlice, (loras) => {
|
||||
@@ -71,9 +74,9 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamScheduler />
|
||||
{!isFLUX && <ParamScheduler />}
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
{isFLUX ? <ParamGuidance /> : <ParamCFGScale />}
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</Expander>
|
||||
|
||||
@@ -167,6 +167,17 @@ const initialConfigState: AppConfig = {
|
||||
},
|
||||
},
|
||||
},
|
||||
flux: {
|
||||
guidance: {
|
||||
initial: 4,
|
||||
sliderMin: 2,
|
||||
sliderMax: 6,
|
||||
numberInputMin: 2,
|
||||
numberInputMax: 6,
|
||||
fineStep: 0.1,
|
||||
coarseStep: 1,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const configSlice = createSlice({
|
||||
@@ -188,6 +199,7 @@ export const selectWidthConfig = createConfigSelector((config) => config.sd.widt
|
||||
export const selectHeightConfig = createConfigSelector((config) => config.sd.height);
|
||||
export const selectStepsConfig = createConfigSelector((config) => config.sd.steps);
|
||||
export const selectCFGScaleConfig = createConfigSelector((config) => config.sd.guidance);
|
||||
export const selectGuidanceConfig = createConfigSelector((config) => config.flux.guidance);
|
||||
export const selectCLIPSkipConfig = createConfigSelector((config) => config.sd.clipSkip);
|
||||
export const selectCFGRescaleMultiplierConfig = createConfigSelector((config) => config.sd.cfgRescaleMultiplier);
|
||||
export const selectCanvasCoherenceEdgeSizeConfig = createConfigSelector((config) => config.sd.canvasCoherenceEdgeSize);
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isClipEmbedModelConfig,
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
@@ -30,18 +30,18 @@ import {
|
||||
|
||||
const buildModelsHook =
|
||||
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
|
||||
() => {
|
||||
const result = useGetModelConfigsQuery(undefined);
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!result.data) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
() => {
|
||||
const result = useGetModelConfigsQuery(undefined);
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!result.data) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
|
||||
}, [result]);
|
||||
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
|
||||
}, [result]);
|
||||
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
|
||||
export const useSDMainModels = buildModelsHook(isNonRefinerNonFluxMainModelConfig);
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
@@ -54,7 +54,7 @@ export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||
export const useT5EncoderModels = buildModelsHook(isT5EncoderModelConfig);
|
||||
export const useClipEmbedModels = buildModelsHook(isClipEmbedModelConfig);
|
||||
export const useCLIPEmbedModels = buildModelsHook(isCLIPEmbedModelConfig);
|
||||
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
|
||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||
|
||||
@@ -52,7 +52,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
|
||||
export type CLIPEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
|
||||
export type T5EncoderModelConfig = S['T5EncoderConfig'];
|
||||
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
|
||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||
@@ -68,7 +68,7 @@ export type AnyModelConfig =
|
||||
| IPAdapterModelConfig
|
||||
| T5EncoderModelConfig
|
||||
| T5EncoderBnbQuantizedLlmInt8bModelConfig
|
||||
| ClipEmbedModelConfig
|
||||
| CLIPEmbedModelConfig
|
||||
| T2IAdapterModelConfig
|
||||
| SpandrelImageToImageModelConfig
|
||||
| TextualInversionModelConfig
|
||||
@@ -105,7 +105,7 @@ export const isT5EncoderModelConfig = (
|
||||
return config.type === 't5_encoder';
|
||||
};
|
||||
|
||||
export const isClipEmbedModelConfig = (config: AnyModelConfig): config is ClipEmbedModelConfig => {
|
||||
export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
|
||||
return config.type === 'clip_embed';
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user