feat(ui): new fields and param layout for FLUX settings: guidance, t5 encoder, CLIP embed

This commit is contained in:
Mary Hipp
2024-09-11 14:25:27 -04:00
committed by psychedelicious
parent 8916036ed3
commit ffbf4aba1f
20 changed files with 339 additions and 63 deletions

View File

@@ -907,6 +907,7 @@
"downloadImage": "Download Image",
"general": "General",
"globalSettings": "Global Settings",
"guidance": "Guidance",
"height": "Height",
"imageFit": "Fit Initial Image To Output Size",
"images": "Images",
@@ -929,6 +930,8 @@
"noModelForControlAdapter": "Control Adapter #{{number}} has no model selected.",
"incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
"canvasManagerNotLoaded": "Canvas Manager not loaded",
"canvasIsFiltering": "Canvas is filtering",
"canvasIsTransforming": "Canvas is transforming",

View File

@@ -10,7 +10,7 @@ import {
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { clipEmbedModelSelected, modelChanged, refinerModelChanged, t5EncoderModelSelected, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
@@ -21,12 +21,14 @@ import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlNetOrT2IAdapterModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
isVAEModelConfig,
} from 'services/api/types';
@@ -50,6 +52,8 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleControlAdapterModels(models, state, dispatch, log);
handleSpandrelImageToImageModels(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
handleT5EncoderModels(models, state, dispatch, log)
handleCLIPEmbedModels(models, state, dispatch, log)
},
});
};
@@ -223,3 +227,31 @@ const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch,
dispatch(postProcessingModelChanged(firstModel));
}
};
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, _log) => {
const { t5EncoderModel: currentT5EncoderModel } = state.params;
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
const firstModel = t5EncoderModels[0] || null;
const isCurrentT5EncoderModelAvailable = currentT5EncoderModel
? t5EncoderModels.some((m) => m.key === currentT5EncoderModel.key)
: false;
if (!isCurrentT5EncoderModelAvailable) {
dispatch(t5EncoderModelSelected(firstModel));
}
};
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, _log) => {
const { clipEmbedModel: currentCLIPEmbedModel } = state.params;
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
const firstModel = CLIPEmbedModels[0] || null;
const isCurrentCLIPEmbedModelAvailable = currentCLIPEmbedModel
? CLIPEmbedModels.some((m) => m.key === currentCLIPEmbedModel.key)
: false;
if (!isCurrentCLIPEmbedModelAvailable) {
dispatch(clipEmbedModelSelected(firstModel));
}
};

View File

@@ -114,6 +114,9 @@ export type AppConfig = {
weight: NumericalParameterConfig;
};
};
flux: {
guidance: NumericalParameterConfig
}
};
export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;

View File

@@ -147,6 +147,16 @@ const createSelector = (
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (model?.base === 'flux') {
console.log({ params })
if (!params.t5EncoderModel) {
reasons.push({ content: i18n.t('parameters.invoke.noT5EncoderModelSelected') });
}
if (!params.clipEmbedModel) {
reasons.push({ content: i18n.t('parameters.invoke.noCLIPEmbedModelSelected') });
}
}
canvas.controlLayers.entities
.filter((controlLayer) => controlLayer.isEnabled)
.forEach((controlLayer, i) => {

View File

@@ -20,6 +20,9 @@ import type {
ParameterSteps,
ParameterStrength,
ParameterVAEModel,
ParameterGuidance,
ParameterT5EncoderModel,
ParameterCLIPEmbedModel
} from 'features/parameters/types/parameterSchemas';
import { clamp } from 'lodash-es';
@@ -35,6 +38,7 @@ export type ParamsState = {
infillColorValue: RgbaColor;
cfgScale: ParameterCFGScale;
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
guidance: ParameterGuidance;
img2imgStrength: ParameterStrength;
iterations: number;
scheduler: ParameterScheduler;
@@ -60,6 +64,8 @@ export type ParamsState = {
refinerPositiveAestheticScore: number;
refinerNegativeAestheticScore: number;
refinerStart: number;
t5EncoderModel: ParameterT5EncoderModel | null,
clipEmbedModel: ParameterCLIPEmbedModel | null
};
const initialState: ParamsState = {
@@ -74,6 +80,7 @@ const initialState: ParamsState = {
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
guidance: 4,
img2imgStrength: 0.75,
iterations: 1,
scheduler: 'euler',
@@ -99,6 +106,8 @@ const initialState: ParamsState = {
refinerPositiveAestheticScore: 6,
refinerNegativeAestheticScore: 2.5,
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null
};
export const paramsSlice = createSlice({
@@ -114,6 +123,9 @@ export const paramsSlice = createSlice({
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
state.cfgScale = action.payload;
},
setGuidance: (state, action: PayloadAction<ParameterGuidance>) => {
state.guidance = action.payload;
},
setCfgRescaleMultiplier: (state, action: PayloadAction<ParameterCFGRescaleMultiplier>) => {
state.cfgRescaleMultiplier = action.payload;
},
@@ -161,6 +173,12 @@ export const paramsSlice = createSlice({
// null is a valid VAE!
state.vae = action.payload;
},
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
},
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload;
},
@@ -246,6 +264,7 @@ export const {
setSteps,
setCfgScale,
setCfgRescaleMultiplier,
setGuidance,
setScheduler,
setSeed,
setImg2imgStrength,
@@ -254,6 +273,8 @@ export const {
setShouldRandomizeSeed,
vaeSelected,
vaePrecisionChanged,
t5EncoderModelSelected,
clipEmbedModelSelected,
setClipSkip,
shouldUseCpuNoiseChanged,
positivePromptChanged,
@@ -289,11 +310,16 @@ export const createParamsSelector = <T>(selector: Selector<ParamsState, T>) =>
export const selectBase = createParamsSelector((params) => params.model?.base);
export const selectIsSDXL = createParamsSelector((params) => params.model?.base === 'sdxl');
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);
export const selectVAE = createParamsSelector((params) => params.vae);
export const selectVAEKey = createParamsSelector((params) => params.vae?.key);
export const selectT5EncoderModel = createParamsSelector((params) => params.t5EncoderModel);
export const selectCLIPEmbedModel = createParamsSelector((params) => params.clipEmbedModel);
export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
export const selectGuidance = createParamsSelector((params) => params.guidance);
export const selectSteps = createParamsSelector((params) => params.steps);
export const selectCFGRescaleMultiplier = createParamsSelector((params) => params.cfgRescaleMultiplier);
export const selectCLIPSKip = createParamsSelector((params) => params.clipSkip);

View File

@@ -19,6 +19,7 @@ import {
setScheduler,
setSeed,
setSteps,
t5EncoderModelSelected,
vaeSelected,
} from 'features/controlLayers/store/paramsSlice';
import type { LoRA } from 'features/controlLayers/store/types';
@@ -44,6 +45,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterT5EncoderModel,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
@@ -154,6 +156,10 @@ const recallVAE: MetadataRecallFunc<ParameterVAEModel | null | undefined> = (vae
getStore().dispatch(vaeSelected(vaeModel));
};
const recallT5Encoder: MetadataRecallFunc<ParameterT5EncoderModel> = (t5EncoderModel) => {
getStore().dispatch(t5EncoderModelSelected(t5EncoderModel));
};
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
getStore().dispatch(loraRecalled({ lora }));
};
@@ -196,4 +202,5 @@ export const recallers = {
vae: recallVAE,
lora: recallLoRA,
loras: recallAllLoRAs,
t5EncoderModel: recallT5Encoder
} as const;

View File

@@ -7,7 +7,7 @@ import type {
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { ParameterSDXLRefinerModel, ParameterT5EncoderModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { BaseModelType } from 'services/api/types';
/**
@@ -40,6 +40,13 @@ const validateVAEModel: MetadataValidateFunc<ParameterVAEModel> = (vaeModel) =>
});
};
const validateT5EncoderModel: MetadataValidateFunc<ParameterT5EncoderModel> = (t5EncoderModel) => {
validateBaseCompatibility('flux', 'T5 Encoder incompatible with currently-selected model');
return new Promise((resolve) => {
resolve(t5EncoderModel);
});
};
const validateLoRA: MetadataValidateFunc<LoRA> = (lora) => {
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
return new Promise((resolve) => {
@@ -131,6 +138,7 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
export const validators = {
refinerModel: validateRefinerModel,
vaeModel: validateVAEModel,
t5EncoderModel: validateT5EncoderModel,
lora: validateLoRA,
loras: validateLoRAs,
controlNet: validateControlNet,

View File

@@ -9,7 +9,7 @@ import {
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useClipEmbedModels,
useCLIPEmbedModels,
useControlNetModels,
useEmbeddingModels,
useIPAdapterModels,
@@ -85,7 +85,7 @@ const ModelList = () => {
[t5EncoderModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useClipEmbedModels();
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels();
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
[clipEmbedModels, searchTerm, filteredModelType]

View File

@@ -5,8 +5,8 @@ import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
import type { ClipEmbedModelConfig } from 'services/api/types';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -17,9 +17,9 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useClipEmbedModels();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: ClipEmbedModelConfig | null) => {
(value: CLIPEmbedModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -0,0 +1,41 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { clipEmbedModelSelected, selectCLIPEmbedModel } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPEmbedModelConfig } from 'services/api/types';
import { useModelCombobox } from '../../../../common/hooks/useModelCombobox';
const ParamCLIPEmbedModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const clipEmbedModel = useAppSelector(selectCLIPEmbedModel);
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(clipEmbedModel: CLIPEmbedModelConfig | null) => {
if (clipEmbedModel) {
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
}
},
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: clipEmbedModel,
isLoading,
});
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
<FormLabel m={0}>{t('modelManager.clipEmbed')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
</FormControl>
);
};
export default memo(ParamCLIPEmbedModelSelect);

View File

@@ -0,0 +1,41 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectT5EncoderModel, t5EncoderModelSelected } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
import { useModelCombobox } from '../../../../common/hooks/useModelCombobox';
const ParamT5EncoderModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const t5EncoderModel = useAppSelector(selectT5EncoderModel);
const [modelConfigs, { isLoading }] = useT5EncoderModels();
const _onChange = useCallback(
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
if (t5EncoderModel) {
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
}
},
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: t5EncoderModel,
isLoading,
});
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
<FormLabel m={0}>{t('modelManager.t5Encoder')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
</FormControl>
);
};
export default memo(ParamT5EncoderModelSelect);

View File

@@ -0,0 +1,49 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectGuidance, setGuidance } from 'features/controlLayers/store/paramsSlice';
import { selectGuidanceConfig } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const ParamGuidance = () => {
const guidance = useAppSelector(selectGuidance);
const config = useAppSelector(selectGuidanceConfig);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const marks = useMemo(
() => [
config.sliderMin,
Math.floor(config.sliderMax - (config.sliderMax - config.sliderMin) / 2),
config.sliderMax,
],
[config.sliderMax, config.sliderMin]
);
const onChange = useCallback((v: number) => dispatch(setGuidance(v)), [dispatch]);
return (
<FormControl>
<FormLabel>{t('parameters.guidance')}</FormLabel>
<CompositeSlider
value={guidance}
defaultValue={config.initial}
min={config.sliderMin}
max={config.sliderMax}
step={config.coarseStep}
fineStep={config.fineStep}
onChange={onChange}
marks={marks}
/>
<CompositeNumberInput
value={guidance}
defaultValue={config.initial}
min={config.numberInputMin}
max={config.numberInputMax}
step={config.coarseStep}
fineStep={config.fineStep}
onChange={onChange}
/>
</FormControl>
);
};
export default memo(ParamGuidance);

View File

@@ -7,14 +7,14 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { modelSelected } from 'features/parameters/store/actions';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSDMainModels } from 'services/api/hooks/modelsByType';
import { useMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModel = useAppSelector(selectModel);
const [modelConfigs, { isLoading }] = useSDMainModels();
const [modelConfigs, { isLoading }] = useMainModels();
const tooltipLabel = useMemo(() => {
if (!modelConfigs.length || !selectedModel) {
return;

View File

@@ -1,6 +1,6 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { createParamsSelector } from 'features/controlLayers/store/paramsSlice';
import { createParamsSelector, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
@@ -15,11 +15,12 @@ const selectWithStylePrompts = createParamsSelector((params) => {
export const Prompts = memo(() => {
const withStylePrompts = useAppSelector(selectWithStylePrompts);
const isFLUX = useAppSelector(selectIsFLUX);
return (
<Flex flexDir="column" gap={2}>
<ParamPositivePrompt />
{withStylePrompts && <ParamSDXLPositiveStylePrompt />}
<ParamNegativePrompt />
{!isFLUX && <ParamNegativePrompt />}
{withStylePrompts && <ParamSDXLNegativeStylePrompt />}
</Flex>
);

View File

@@ -56,6 +56,13 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
zParameterCFGScale.safeParse(val).success;
// #endregion
// #region Guidance parameter
const zParameterGuidance = z.number().min(1);
export type ParameterGuidance = z.infer<typeof zParameterGuidance>;
export const isParameterGuidance = (val: unknown): val is ParameterGuidance =>
zParameterGuidance.safeParse(val).success;
// #endregion
// #region CFG Rescale Multiplier
const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
export type ParameterCFGRescaleMultiplier = z.infer<typeof zParameterCFGRescaleMultiplier>;
@@ -106,6 +113,16 @@ export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion
// #region T5Encoder Model
export const zParameterT5EncoderModel = zModelIdentifierField;
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;
// #endregion
// #region CLIP embed Model
export const zParameterCLIPEmbedModel = zModelIdentifierField;
export type ParameterCLIPEmbedModel = z.infer<typeof zParameterCLIPEmbedModel>;
// #endregion
// #region LoRA Model
const zParameterLoRAModel = zModelIdentifierField;
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;

View File

@@ -3,7 +3,7 @@ import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-libra
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
import ParamClipSkip from 'features/parameters/components/Advanced/ParamClipSkip';
import ParamSeamlessXAxis from 'features/parameters/components/Seamless/ParamSeamlessXAxis';
@@ -18,6 +18,8 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import ParamT5EncoderModelSelect from '../../../parameters/components/Advanced/ParamT5EncoderModelSelect';
import ParamCLIPEmbedModelSelect from '../../../parameters/components/Advanced/ParamCLIPEmbedModelSelect';
const formLabelProps: FormLabelProps = {
minW: '9.2rem',
@@ -31,32 +33,44 @@ export const AdvancedSettingsAccordion = memo(() => {
const vaeKey = useAppSelector(selectVAEKey);
const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
const activeTabName = useAppSelector(selectActiveTab);
const isFLUX = useAppSelector(selectIsFLUX);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectParamsSlice, (params) => {
createMemoizedSelector([selectParamsSlice, selectIsFLUX], (params, isFLUX) => {
const badges: (string | number)[] = [];
if (vaeConfig) {
let vaeBadge = vaeConfig.name;
if (params.vaePrecision === 'fp16') {
vaeBadge += ` ${params.vaePrecision}`;
if (isFLUX) {
if (vaeConfig) {
let vaeBadge = vaeConfig.name;
if (params.vaePrecision === 'fp16') {
vaeBadge += ` ${params.vaePrecision}`;
}
badges.push(vaeBadge);
}
} else {
if (vaeConfig) {
let vaeBadge = vaeConfig.name;
if (params.vaePrecision === 'fp16') {
vaeBadge += ` ${params.vaePrecision}`;
}
badges.push(vaeBadge);
} else if (params.vaePrecision === 'fp16') {
badges.push(`VAE ${params.vaePrecision}`);
}
if (params.clipSkip) {
badges.push(`Skip ${params.clipSkip}`);
}
if (params.cfgRescaleMultiplier) {
badges.push(`Rescale ${params.cfgRescaleMultiplier}`);
}
if (params.seamlessXAxis || params.seamlessYAxis) {
badges.push('seamless');
}
if (activeTabName === 'upscaling' && !params.shouldRandomizeSeed) {
badges.push('Manual Seed');
}
badges.push(vaeBadge);
} else if (params.vaePrecision === 'fp16') {
badges.push(`VAE ${params.vaePrecision}`);
}
if (params.clipSkip) {
badges.push(`Skip ${params.clipSkip}`);
}
if (params.cfgRescaleMultiplier) {
badges.push(`Rescale ${params.cfgRescaleMultiplier}`);
}
if (params.seamlessXAxis || params.seamlessYAxis) {
badges.push('seamless');
}
if (activeTabName === 'upscaling' && !params.shouldRandomizeSeed) {
badges.push('Manual Seed');
}
return badges;
}),
[vaeConfig, activeTabName]
@@ -73,27 +87,36 @@ export const AdvancedSettingsAccordion = memo(() => {
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
<Flex gap={4} w="full">
<ParamVAEModelSelect />
<ParamVAEPrecision />
{!isFLUX && <ParamVAEPrecision />}
</Flex>
{activeTabName === 'upscaling' && (
{activeTabName === 'upscaling' ? (
<Flex gap={4} alignItems="center">
<ParamSeedNumberInput />
<ParamSeedShuffle />
<ParamSeedRandomize />
</Flex>
)}
{activeTabName !== 'upscaling' && (
) : (
<>
<FormControlGroup formLabelProps={formLabelProps}>
<ParamClipSkip />
<ParamCFGRescaleMultiplier />
</FormControlGroup>
<Flex gap={4} w="full">
<FormControlGroup formLabelProps={formLabelProps2}>
<ParamSeamlessXAxis />
<ParamSeamlessYAxis />
{!isFLUX && (
<>
<FormControlGroup formLabelProps={formLabelProps}>
<ParamClipSkip />
<ParamCFGRescaleMultiplier />
</FormControlGroup>
<Flex gap={4} w="full">
<FormControlGroup formLabelProps={formLabelProps2}>
<ParamSeamlessXAxis />
<ParamSeamlessYAxis />
</FormControlGroup>
</Flex>
</>
)}
{isFLUX && (
<FormControlGroup formLabelProps={formLabelProps}>
<ParamT5EncoderModelSelect />
<ParamCLIPEmbedModelSelect />
</FormControlGroup>
</Flex>
)}
</>
)}
</Flex>

View File

@@ -18,6 +18,8 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
import ParamGuidance from '../../../parameters/components/Core/ParamGuidance';
import { selectIsFLUX } from '../../../controlLayers/store/paramsSlice';
const formLabelProps: FormLabelProps = {
minW: '4rem',
@@ -27,6 +29,7 @@ export const GenerationSettingsAccordion = memo(() => {
const { t } = useTranslation();
const modelConfig = useSelectedModelConfig();
const activeTabName = useAppSelector(selectActiveTab);
const isFLUX = useAppSelector(selectIsFLUX);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectLoRAsSlice, (loras) => {
@@ -71,9 +74,9 @@ export const GenerationSettingsAccordion = memo(() => {
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} flexDir="column" pb={4}>
<FormControlGroup formLabelProps={formLabelProps}>
<ParamScheduler />
{!isFLUX && <ParamScheduler />}
<ParamSteps />
<ParamCFGScale />
{isFLUX ? <ParamGuidance /> : <ParamCFGScale />}
</FormControlGroup>
</Flex>
</Expander>

View File

@@ -167,6 +167,17 @@ const initialConfigState: AppConfig = {
},
},
},
flux: {
guidance: {
initial: 4,
sliderMin: 2,
sliderMax: 6,
numberInputMin: 2,
numberInputMax: 6,
fineStep: 0.1,
coarseStep: 1,
}
}
};
export const configSlice = createSlice({
@@ -188,6 +199,7 @@ export const selectWidthConfig = createConfigSelector((config) => config.sd.widt
export const selectHeightConfig = createConfigSelector((config) => config.sd.height);
export const selectStepsConfig = createConfigSelector((config) => config.sd.steps);
export const selectCFGScaleConfig = createConfigSelector((config) => config.sd.guidance);
export const selectGuidanceConfig = createConfigSelector((config) => config.flux.guidance);
export const selectCLIPSkipConfig = createConfigSelector((config) => config.sd.clipSkip);
export const selectCFGRescaleMultiplierConfig = createConfigSelector((config) => config.sd.cfgRescaleMultiplier);
export const selectCanvasCoherenceEdgeSizeConfig = createConfigSelector((config) => config.sd.canvasCoherenceEdgeSize);

View File

@@ -9,7 +9,7 @@ import {
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isClipEmbedModelConfig,
isCLIPEmbedModelConfig,
isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxMainModelModelConfig,
@@ -30,18 +30,18 @@ import {
const buildModelsHook =
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
return EMPTY_ARRAY;
}
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
}, [result]);
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
}, [result]);
return [modelConfigs, result] as const;
};
return [modelConfigs, result] as const;
};
export const useSDMainModels = buildModelsHook(isNonRefinerNonFluxMainModelConfig);
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
@@ -54,7 +54,7 @@ export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useT5EncoderModels = buildModelsHook(isT5EncoderModelConfig);
export const useClipEmbedModels = buildModelsHook(isClipEmbedModelConfig);
export const useCLIPEmbedModels = buildModelsHook(isCLIPEmbedModelConfig);
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);

View File

@@ -52,7 +52,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type CLIPEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type T5EncoderModelConfig = S['T5EncoderConfig'];
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
@@ -68,7 +68,7 @@ export type AnyModelConfig =
| IPAdapterModelConfig
| T5EncoderModelConfig
| T5EncoderBnbQuantizedLlmInt8bModelConfig
| ClipEmbedModelConfig
| CLIPEmbedModelConfig
| T2IAdapterModelConfig
| SpandrelImageToImageModelConfig
| TextualInversionModelConfig
@@ -105,7 +105,7 @@ export const isT5EncoderModelConfig = (
return config.type === 't5_encoder';
};
export const isClipEmbedModelConfig = (config: AnyModelConfig): config is ClipEmbedModelConfig => {
export const isCLIPEmbedModelConfig = (config: AnyModelConfig): config is CLIPEmbedModelConfig => {
return config.type === 'clip_embed';
};