mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-16 09:15:59 -05:00
feat(ui): new fields and param layout for FLUX settings: guidance, t5 encoder, CLIP embed
This commit is contained in:
committed by
psychedelicious
parent
8916036ed3
commit
ffbf4aba1f
@@ -20,6 +20,9 @@ import type {
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterVAEModel,
|
||||
ParameterGuidance,
|
||||
ParameterT5EncoderModel,
|
||||
ParameterCLIPEmbedModel
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { clamp } from 'lodash-es';
|
||||
|
||||
@@ -35,6 +38,7 @@ export type ParamsState = {
|
||||
infillColorValue: RgbaColor;
|
||||
cfgScale: ParameterCFGScale;
|
||||
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||
guidance: ParameterGuidance;
|
||||
img2imgStrength: ParameterStrength;
|
||||
iterations: number;
|
||||
scheduler: ParameterScheduler;
|
||||
@@ -60,6 +64,8 @@ export type ParamsState = {
|
||||
refinerPositiveAestheticScore: number;
|
||||
refinerNegativeAestheticScore: number;
|
||||
refinerStart: number;
|
||||
t5EncoderModel: ParameterT5EncoderModel | null,
|
||||
clipEmbedModel: ParameterCLIPEmbedModel | null
|
||||
};
|
||||
|
||||
const initialState: ParamsState = {
|
||||
@@ -74,6 +80,7 @@ const initialState: ParamsState = {
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
cfgScale: 7.5,
|
||||
cfgRescaleMultiplier: 0,
|
||||
guidance: 4,
|
||||
img2imgStrength: 0.75,
|
||||
iterations: 1,
|
||||
scheduler: 'euler',
|
||||
@@ -99,6 +106,8 @@ const initialState: ParamsState = {
|
||||
refinerPositiveAestheticScore: 6,
|
||||
refinerNegativeAestheticScore: 2.5,
|
||||
refinerStart: 0.8,
|
||||
t5EncoderModel: null,
|
||||
clipEmbedModel: null
|
||||
};
|
||||
|
||||
export const paramsSlice = createSlice({
|
||||
@@ -114,6 +123,9 @@ export const paramsSlice = createSlice({
|
||||
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
|
||||
state.cfgScale = action.payload;
|
||||
},
|
||||
setGuidance: (state, action: PayloadAction<ParameterGuidance>) => {
|
||||
state.guidance = action.payload;
|
||||
},
|
||||
setCfgRescaleMultiplier: (state, action: PayloadAction<ParameterCFGRescaleMultiplier>) => {
|
||||
state.cfgRescaleMultiplier = action.payload;
|
||||
},
|
||||
@@ -161,6 +173,12 @@ export const paramsSlice = createSlice({
|
||||
// null is a valid VAE!
|
||||
state.vae = action.payload;
|
||||
},
|
||||
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
|
||||
state.t5EncoderModel = action.payload;
|
||||
},
|
||||
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
|
||||
state.clipEmbedModel = action.payload;
|
||||
},
|
||||
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
|
||||
state.vaePrecision = action.payload;
|
||||
},
|
||||
@@ -246,6 +264,7 @@ export const {
|
||||
setSteps,
|
||||
setCfgScale,
|
||||
setCfgRescaleMultiplier,
|
||||
setGuidance,
|
||||
setScheduler,
|
||||
setSeed,
|
||||
setImg2imgStrength,
|
||||
@@ -254,6 +273,8 @@ export const {
|
||||
setShouldRandomizeSeed,
|
||||
vaeSelected,
|
||||
vaePrecisionChanged,
|
||||
t5EncoderModelSelected,
|
||||
clipEmbedModelSelected,
|
||||
setClipSkip,
|
||||
shouldUseCpuNoiseChanged,
|
||||
positivePromptChanged,
|
||||
@@ -289,11 +310,16 @@ export const createParamsSelector = <T>(selector: Selector<ParamsState, T>) =>
|
||||
|
||||
export const selectBase = createParamsSelector((params) => params.model?.base);
|
||||
export const selectIsSDXL = createParamsSelector((params) => params.model?.base === 'sdxl');
|
||||
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
|
||||
|
||||
export const selectModel = createParamsSelector((params) => params.model);
|
||||
export const selectModelKey = createParamsSelector((params) => params.model?.key);
|
||||
export const selectVAE = createParamsSelector((params) => params.vae);
|
||||
export const selectVAEKey = createParamsSelector((params) => params.vae?.key);
|
||||
export const selectT5EncoderModel = createParamsSelector((params) => params.t5EncoderModel);
|
||||
export const selectCLIPEmbedModel = createParamsSelector((params) => params.clipEmbedModel);
|
||||
export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
|
||||
export const selectGuidance = createParamsSelector((params) => params.guidance);
|
||||
export const selectSteps = createParamsSelector((params) => params.steps);
|
||||
export const selectCFGRescaleMultiplier = createParamsSelector((params) => params.cfgRescaleMultiplier);
|
||||
export const selectCLIPSKip = createParamsSelector((params) => params.clipSkip);
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
setScheduler,
|
||||
setSeed,
|
||||
setSteps,
|
||||
t5EncoderModelSelected,
|
||||
vaeSelected,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import type { LoRA } from 'features/controlLayers/store/types';
|
||||
@@ -44,6 +45,7 @@ import type {
|
||||
ParameterSeed,
|
||||
ParameterSteps,
|
||||
ParameterStrength,
|
||||
ParameterT5EncoderModel,
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
@@ -154,6 +156,10 @@ const recallVAE: MetadataRecallFunc<ParameterVAEModel | null | undefined> = (vae
|
||||
getStore().dispatch(vaeSelected(vaeModel));
|
||||
};
|
||||
|
||||
const recallT5Encoder: MetadataRecallFunc<ParameterT5EncoderModel> = (t5EncoderModel) => {
|
||||
getStore().dispatch(t5EncoderModelSelected(t5EncoderModel));
|
||||
};
|
||||
|
||||
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
||||
getStore().dispatch(loraRecalled({ lora }));
|
||||
};
|
||||
@@ -196,4 +202,5 @@ export const recallers = {
|
||||
vae: recallVAE,
|
||||
lora: recallLoRA,
|
||||
loras: recallAllLoRAs,
|
||||
t5EncoderModel: recallT5Encoder
|
||||
} as const;
|
||||
|
||||
@@ -7,7 +7,7 @@ import type {
|
||||
T2IAdapterConfigMetadata,
|
||||
} from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ParameterSDXLRefinerModel, ParameterT5EncoderModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { BaseModelType } from 'services/api/types';
|
||||
|
||||
/**
|
||||
@@ -40,6 +40,13 @@ const validateVAEModel: MetadataValidateFunc<ParameterVAEModel> = (vaeModel) =>
|
||||
});
|
||||
};
|
||||
|
||||
const validateT5EncoderModel: MetadataValidateFunc<ParameterT5EncoderModel> = (t5EncoderModel) => {
|
||||
validateBaseCompatibility('flux', 'T5 Encoder incompatible with currently-selected model');
|
||||
return new Promise((resolve) => {
|
||||
resolve(t5EncoderModel);
|
||||
});
|
||||
};
|
||||
|
||||
const validateLoRA: MetadataValidateFunc<LoRA> = (lora) => {
|
||||
validateBaseCompatibility(lora.model.base, 'LoRA incompatible with currently-selected model');
|
||||
return new Promise((resolve) => {
|
||||
@@ -131,6 +138,7 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
|
||||
export const validators = {
|
||||
refinerModel: validateRefinerModel,
|
||||
vaeModel: validateVAEModel,
|
||||
t5EncoderModel: validateT5EncoderModel,
|
||||
lora: validateLoRA,
|
||||
loras: validateLoRAs,
|
||||
controlNet: validateControlNet,
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
useClipEmbedModels,
|
||||
useCLIPEmbedModels,
|
||||
useControlNetModels,
|
||||
useEmbeddingModels,
|
||||
useIPAdapterModels,
|
||||
@@ -85,7 +85,7 @@ const ModelList = () => {
|
||||
[t5EncoderModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useClipEmbedModels();
|
||||
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels();
|
||||
const filteredClipEmbedModels = useMemo(
|
||||
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
|
||||
[clipEmbedModels, searchTerm, filteredModelType]
|
||||
|
||||
@@ -5,8 +5,8 @@ import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ClipEmbedModelConfig } from 'services/api/types';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -17,9 +17,9 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useClipEmbedModels();
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
(value: ClipEmbedModelConfig | null) => {
|
||||
(value: CLIPEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { clipEmbedModelSelected, selectCLIPEmbedModel } from 'features/controlLayers/store/paramsSlice';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { CLIPEmbedModelConfig } from 'services/api/types';
|
||||
import { useModelCombobox } from '../../../../common/hooks/useModelCombobox';
|
||||
|
||||
const ParamCLIPEmbedModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const clipEmbedModel = useAppSelector(selectCLIPEmbedModel);
|
||||
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(clipEmbedModel: CLIPEmbedModelConfig | null) => {
|
||||
if (clipEmbedModel) {
|
||||
dispatch(clipEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: clipEmbedModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
|
||||
<FormLabel m={0}>{t('modelManager.clipEmbed')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamCLIPEmbedModelSelect);
|
||||
@@ -0,0 +1,41 @@
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectT5EncoderModel, t5EncoderModelSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useT5EncoderModels } from 'services/api/hooks/modelsByType';
|
||||
import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types';
|
||||
import { useModelCombobox } from '../../../../common/hooks/useModelCombobox';
|
||||
|
||||
const ParamT5EncoderModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const t5EncoderModel = useAppSelector(selectT5EncoderModel);
|
||||
const [modelConfigs, { isLoading }] = useT5EncoderModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(t5EncoderModel: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => {
|
||||
if (t5EncoderModel) {
|
||||
dispatch(t5EncoderModelSelected(zModelIdentifierField.parse(t5EncoderModel)));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: t5EncoderModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
|
||||
<FormLabel m={0}>{t('modelManager.t5Encoder')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamT5EncoderModelSelect);
|
||||
@@ -0,0 +1,49 @@
|
||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectGuidance, setGuidance } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectGuidanceConfig } from 'features/system/store/configSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamGuidance = () => {
|
||||
const guidance = useAppSelector(selectGuidance);
|
||||
const config = useAppSelector(selectGuidanceConfig);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const marks = useMemo(
|
||||
() => [
|
||||
config.sliderMin,
|
||||
Math.floor(config.sliderMax - (config.sliderMax - config.sliderMin) / 2),
|
||||
config.sliderMax,
|
||||
],
|
||||
[config.sliderMax, config.sliderMin]
|
||||
);
|
||||
const onChange = useCallback((v: number) => dispatch(setGuidance(v)), [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('parameters.guidance')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={guidance}
|
||||
defaultValue={config.initial}
|
||||
min={config.sliderMin}
|
||||
max={config.sliderMax}
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
onChange={onChange}
|
||||
marks={marks}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={guidance}
|
||||
defaultValue={config.initial}
|
||||
min={config.numberInputMin}
|
||||
max={config.numberInputMax}
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
onChange={onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamGuidance);
|
||||
@@ -7,14 +7,14 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSDMainModels } from 'services/api/hooks/modelsByType';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const ParamMainModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const selectedModel = useAppSelector(selectModel);
|
||||
const [modelConfigs, { isLoading }] = useSDMainModels();
|
||||
const [modelConfigs, { isLoading }] = useMainModels();
|
||||
const tooltipLabel = useMemo(() => {
|
||||
if (!modelConfigs.length || !selectedModel) {
|
||||
return;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { createParamsSelector } from 'features/controlLayers/store/paramsSlice';
|
||||
import { createParamsSelector, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
|
||||
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
|
||||
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
|
||||
@@ -15,11 +15,12 @@ const selectWithStylePrompts = createParamsSelector((params) => {
|
||||
|
||||
export const Prompts = memo(() => {
|
||||
const withStylePrompts = useAppSelector(selectWithStylePrompts);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<ParamPositivePrompt />
|
||||
{withStylePrompts && <ParamSDXLPositiveStylePrompt />}
|
||||
<ParamNegativePrompt />
|
||||
{!isFLUX && <ParamNegativePrompt />}
|
||||
{withStylePrompts && <ParamSDXLNegativeStylePrompt />}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -56,6 +56,13 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
||||
zParameterCFGScale.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region Guidance parameter
|
||||
const zParameterGuidance = z.number().min(1);
|
||||
export type ParameterGuidance = z.infer<typeof zParameterGuidance>;
|
||||
export const isParameterGuidance = (val: unknown): val is ParameterGuidance =>
|
||||
zParameterGuidance.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region CFG Rescale Multiplier
|
||||
const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||
export type ParameterCFGRescaleMultiplier = z.infer<typeof zParameterCFGRescaleMultiplier>;
|
||||
@@ -106,6 +113,16 @@ export const zParameterVAEModel = zModelIdentifierField;
|
||||
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
|
||||
// #endregion
|
||||
|
||||
// #region T5Encoder Model
|
||||
export const zParameterT5EncoderModel = zModelIdentifierField;
|
||||
export type ParameterT5EncoderModel = z.infer<typeof zParameterT5EncoderModel>;
|
||||
// #endregion
|
||||
|
||||
// #region CLIP embed Model
|
||||
export const zParameterCLIPEmbedModel = zModelIdentifierField;
|
||||
export type ParameterCLIPEmbedModel = z.infer<typeof zParameterCLIPEmbedModel>;
|
||||
// #endregion
|
||||
|
||||
// #region LoRA Model
|
||||
const zParameterLoRAModel = zModelIdentifierField;
|
||||
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-libra
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsFLUX, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
|
||||
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
|
||||
import ParamClipSkip from 'features/parameters/components/Advanced/ParamClipSkip';
|
||||
import ParamSeamlessXAxis from 'features/parameters/components/Seamless/ParamSeamlessXAxis';
|
||||
@@ -18,6 +18,8 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import ParamT5EncoderModelSelect from '../../../parameters/components/Advanced/ParamT5EncoderModelSelect';
|
||||
import ParamCLIPEmbedModelSelect from '../../../parameters/components/Advanced/ParamCLIPEmbedModelSelect';
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
minW: '9.2rem',
|
||||
@@ -31,32 +33,44 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
const vaeKey = useAppSelector(selectVAEKey);
|
||||
const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectParamsSlice, (params) => {
|
||||
createMemoizedSelector([selectParamsSlice, selectIsFLUX], (params, isFLUX) => {
|
||||
const badges: (string | number)[] = [];
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (params.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${params.vaePrecision}`;
|
||||
if (isFLUX) {
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (params.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${params.vaePrecision}`;
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
}
|
||||
} else {
|
||||
if (vaeConfig) {
|
||||
let vaeBadge = vaeConfig.name;
|
||||
if (params.vaePrecision === 'fp16') {
|
||||
vaeBadge += ` ${params.vaePrecision}`;
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
} else if (params.vaePrecision === 'fp16') {
|
||||
badges.push(`VAE ${params.vaePrecision}`);
|
||||
}
|
||||
if (params.clipSkip) {
|
||||
badges.push(`Skip ${params.clipSkip}`);
|
||||
}
|
||||
if (params.cfgRescaleMultiplier) {
|
||||
badges.push(`Rescale ${params.cfgRescaleMultiplier}`);
|
||||
}
|
||||
if (params.seamlessXAxis || params.seamlessYAxis) {
|
||||
badges.push('seamless');
|
||||
}
|
||||
if (activeTabName === 'upscaling' && !params.shouldRandomizeSeed) {
|
||||
badges.push('Manual Seed');
|
||||
}
|
||||
badges.push(vaeBadge);
|
||||
} else if (params.vaePrecision === 'fp16') {
|
||||
badges.push(`VAE ${params.vaePrecision}`);
|
||||
}
|
||||
if (params.clipSkip) {
|
||||
badges.push(`Skip ${params.clipSkip}`);
|
||||
}
|
||||
if (params.cfgRescaleMultiplier) {
|
||||
badges.push(`Rescale ${params.cfgRescaleMultiplier}`);
|
||||
}
|
||||
if (params.seamlessXAxis || params.seamlessYAxis) {
|
||||
badges.push('seamless');
|
||||
}
|
||||
if (activeTabName === 'upscaling' && !params.shouldRandomizeSeed) {
|
||||
badges.push('Manual Seed');
|
||||
}
|
||||
|
||||
return badges;
|
||||
}),
|
||||
[vaeConfig, activeTabName]
|
||||
@@ -73,27 +87,36 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
|
||||
<Flex gap={4} w="full">
|
||||
<ParamVAEModelSelect />
|
||||
<ParamVAEPrecision />
|
||||
{!isFLUX && <ParamVAEPrecision />}
|
||||
</Flex>
|
||||
{activeTabName === 'upscaling' && (
|
||||
{activeTabName === 'upscaling' ? (
|
||||
<Flex gap={4} alignItems="center">
|
||||
<ParamSeedNumberInput />
|
||||
<ParamSeedShuffle />
|
||||
<ParamSeedRandomize />
|
||||
</Flex>
|
||||
)}
|
||||
{activeTabName !== 'upscaling' && (
|
||||
) : (
|
||||
<>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamClipSkip />
|
||||
<ParamCFGRescaleMultiplier />
|
||||
</FormControlGroup>
|
||||
<Flex gap={4} w="full">
|
||||
<FormControlGroup formLabelProps={formLabelProps2}>
|
||||
<ParamSeamlessXAxis />
|
||||
<ParamSeamlessYAxis />
|
||||
{!isFLUX && (
|
||||
<>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamClipSkip />
|
||||
<ParamCFGRescaleMultiplier />
|
||||
</FormControlGroup>
|
||||
<Flex gap={4} w="full">
|
||||
<FormControlGroup formLabelProps={formLabelProps2}>
|
||||
<ParamSeamlessXAxis />
|
||||
<ParamSeamlessYAxis />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{isFLUX && (
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamT5EncoderModelSelect />
|
||||
<ParamCLIPEmbedModelSelect />
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
@@ -18,6 +18,8 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
|
||||
import ParamGuidance from '../../../parameters/components/Core/ParamGuidance';
|
||||
import { selectIsFLUX } from '../../../controlLayers/store/paramsSlice';
|
||||
|
||||
const formLabelProps: FormLabelProps = {
|
||||
minW: '4rem',
|
||||
@@ -27,6 +29,7 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoRAsSlice, (loras) => {
|
||||
@@ -71,9 +74,9 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamScheduler />
|
||||
{!isFLUX && <ParamScheduler />}
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
{isFLUX ? <ParamGuidance /> : <ParamCFGScale />}
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
</Expander>
|
||||
|
||||
@@ -167,6 +167,17 @@ const initialConfigState: AppConfig = {
|
||||
},
|
||||
},
|
||||
},
|
||||
flux: {
|
||||
guidance: {
|
||||
initial: 4,
|
||||
sliderMin: 2,
|
||||
sliderMax: 6,
|
||||
numberInputMin: 2,
|
||||
numberInputMax: 6,
|
||||
fineStep: 0.1,
|
||||
coarseStep: 1,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const configSlice = createSlice({
|
||||
@@ -188,6 +199,7 @@ export const selectWidthConfig = createConfigSelector((config) => config.sd.widt
|
||||
export const selectHeightConfig = createConfigSelector((config) => config.sd.height);
|
||||
export const selectStepsConfig = createConfigSelector((config) => config.sd.steps);
|
||||
export const selectCFGScaleConfig = createConfigSelector((config) => config.sd.guidance);
|
||||
export const selectGuidanceConfig = createConfigSelector((config) => config.flux.guidance);
|
||||
export const selectCLIPSkipConfig = createConfigSelector((config) => config.sd.clipSkip);
|
||||
export const selectCFGRescaleMultiplierConfig = createConfigSelector((config) => config.sd.cfgRescaleMultiplier);
|
||||
export const selectCanvasCoherenceEdgeSizeConfig = createConfigSelector((config) => config.sd.canvasCoherenceEdgeSize);
|
||||
|
||||
Reference in New Issue
Block a user