feat(ui,api): add guidance as a default setting option for FLUX models

This commit is contained in:
Mary Hipp
2024-09-26 12:10:48 -04:00
committed by Mary Hipp Rogers
parent ca55ef1da5
commit c224971cb4
7 changed files with 119 additions and 5 deletions

View File

@@ -157,6 +157,7 @@ class MainModelDefaultSettings(BaseModel):
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
model_config = ConfigDict(extra="forbid")

View File

@@ -4,6 +4,7 @@ import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaS
import {
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
setScheduler,
setSteps,
vaePrecisionChanged,
@@ -13,6 +14,7 @@ import { setDefaultSettings } from 'features/parameters/store/actions';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterGuidance,
isParameterHeight,
isParameterPrecision,
isParameterScheduler,
@@ -49,7 +51,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height, guidance } =
modelConfig.default_settings;
if (vae) {
@@ -73,6 +75,12 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
if (guidance) {
if (isParameterGuidance(guidance)) {
dispatch(setGuidance(guidance));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));

View File

@@ -7,6 +7,7 @@ import type { MainModelConfig } from 'services/api/types';
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
const { guidance: fluxGuidance } = config.flux;
return {
initialSteps: steps.initial,
@@ -16,6 +17,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
initialVaePrecision: vaePrecision,
initialWidth: width.initial,
initialHeight: height.initial,
initialGuidance: fluxGuidance.initial,
};
});
@@ -28,6 +30,7 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
initialVaePrecision,
initialWidth,
initialHeight,
initialGuidance,
} = useAppSelector(initialStatesSelector);
const defaultSettingsDefaults = useMemo(() => {
@@ -64,6 +67,10 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
isEnabled: !isNil(modelConfig?.default_settings?.height),
value: modelConfig?.default_settings?.height || initialHeight,
},
guidance: {
isEnabled: !isNil(modelConfig?.default_settings?.guidance),
value: modelConfig?.default_settings?.guidance || initialGuidance,
},
};
}, [
modelConfig,
@@ -74,6 +81,7 @@ export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
initialCfgRescaleMultiplier,
initialWidth,
initialHeight,
initialGuidance,
]);
return defaultSettingsDefaults;

View File

@@ -0,0 +1,82 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { selectGuidanceConfig } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
type DefaultGuidanceType = MainModelDefaultSettingsFormData['guidance'];
export const DefaultGuidance = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
const { field } = useController(props);
const config = useAppSelector(selectGuidanceConfig);
const { t } = useTranslation();
const marks = useMemo(
() => [
config.sliderMin,
Math.floor(config.sliderMax - (config.sliderMax - config.sliderMin) / 2),
config.sliderMax,
],
[config.sliderMax, config.sliderMin]
);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultGuidanceType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultGuidanceType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultGuidanceType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="paramGuidance">
<FormLabel>{t('parameters.guidance')}</FormLabel>
</InformationalPopover>
<SettingToggle control={props.control} name="guidance" />
</Flex>
<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={config.sliderMin}
max={config.sliderMax}
step={config.coarseStep}
fineStep={config.fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={config.numberInputMin}
max={config.numberInputMax}
step={config.coarseStep}
fineStep={config.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});
DefaultGuidance.displayName = 'DefaultGuidance';

View File

@@ -17,6 +17,7 @@ import type { MainModelConfig } from 'services/api/types';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
import { DefaultGuidance } from './DefaultGuidance';
import { DefaultScheduler } from './DefaultScheduler';
import { DefaultSteps } from './DefaultSteps';
import { DefaultVae } from './DefaultVae';
@@ -36,6 +37,7 @@ export type MainModelDefaultSettingsFormData = {
cfgRescaleMultiplier: FormField<number>;
width: FormField<number>;
height: FormField<number>;
guidance: FormField<number>;
};
type Props = {
@@ -46,6 +48,10 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
const selectedModelKey = useAppSelector(selectSelectedModelKey);
const { t } = useTranslation();
const isFlux = useMemo(() => {
return modelConfig.base === 'flux';
}, [modelConfig]);
const defaultSettingsDefaults = useMainModelDefaultSettings(modelConfig);
const optimalDimension = useMemo(() => {
const modelBase = modelConfig?.base;
@@ -72,6 +78,7 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
width: data.width.isEnabled ? data.width.value : null,
height: data.height.isEnabled ? data.height.value : null,
guidance: data.guidance.isEnabled ? data.guidance.value : null,
};
updateModel({
@@ -118,11 +125,12 @@ export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
<SimpleGrid columns={2} gap={8}>
<DefaultVae control={control} name="vae" />
<DefaultVaePrecision control={control} name="vaePrecision" />
<DefaultScheduler control={control} name="scheduler" />
{!isFlux && <DefaultVaePrecision control={control} name="vaePrecision" />}
{!isFlux && <DefaultScheduler control={control} name="scheduler" />}
<DefaultSteps control={control} name="steps" />
<DefaultCfgScale control={control} name="cfgScale" />
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
{isFlux && <DefaultGuidance control={control} name="guidance" />}
{!isFlux && <DefaultCfgScale control={control} name="cfgScale" />}
{!isFlux && <DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />}
<DefaultWidth control={control} optimalDimension={optimalDimension} />
<DefaultHeight control={control} optimalDimension={optimalDimension} />
</SimpleGrid>

View File

@@ -59,6 +59,8 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
// #region Guidance parameter
const zParameterGuidance = z.number().min(1);
export type ParameterGuidance = z.infer<typeof zParameterGuidance>;
export const isParameterGuidance = (val: unknown): val is ParameterGuidance =>
zParameterGuidance.safeParse(val).success;
// #endregion
// #region CFG Rescale Multiplier

View File

@@ -11037,6 +11037,11 @@ export type components = {
* @description Default height for this model
*/
height?: number | null;
/**
* Guidance
* @description Default Guidance for this model
*/
guidance?: number | null;
};
/**
* Main Model