(ui): replace logic for controlnet/t2i to include control_loras and display default settings in model manager

This commit is contained in:
Mary Hipp
2024-12-12 21:27:08 -05:00
committed by Kent Keirsey
parent da213e4638
commit 92b0d89b70
13 changed files with 73 additions and 47 deletions

View File

@@ -30,7 +30,7 @@ import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlNetOrT2IAdapterModelConfig,
isControlLayerModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
@@ -190,7 +190,7 @@ const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
const caModels = models.filter(isControlLayerModelConfig);
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
const selectedControlAdapterModel = entity.controlAdapter.model;
// `null` is a valid control adapter model - no need to do anything.

View File

@@ -26,7 +26,12 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
ImageDTO,
T2IAdapterModelConfig,
} from 'services/api/types';
const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) =>
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
@@ -157,8 +162,10 @@ export const ControlLayerControlAdapter = memo(() => {
/>
<input {...uploadApi.getUploadInputProps()} />
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type !== 'control_lora' && <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />}
{controlAdapter.type !== 'control_lora' && (
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
)}
{controlAdapter.type === 'controlnet' && !isFLUX && (
<ControlLayerControlAdapterControlMode
controlMode={controlAdapter.controlMode}

View File

@@ -5,7 +5,12 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlLayerModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type {
AnyModelConfig,
ControlLoRAModelConfig,
ControlNetModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
type Props = {
modelKey: string | null;

View File

@@ -18,6 +18,7 @@ import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/se
import type {
CanvasEntityIdentifier,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
@@ -26,8 +27,13 @@ import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
import { isControlLayerModelConfig, isIPAdapterModelConfig } from 'services/api/types';
/**
* Selects the default control adapter configuration based on the model configurations and the base.
@@ -39,13 +45,13 @@ import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'ser
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
(query, base): ControlNetConfig | T2IAdapterConfig => {
(query, base): ControlNetConfig | T2IAdapterConfig | ControlLoRAConfig => {
const { data } = query;
let model: ControlNetModelConfig | T2IAdapterModelConfig | null = null;
let model: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors
.selectAll(data)
.filter(isControlNetOrT2IAdapterModelConfig)
.filter(isControlLayerModelConfig)
.sort((a) => (a.type === 'controlnet' ? -1 : 1)); // Prefer ControlNet models
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;

View File

@@ -18,7 +18,7 @@ import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
import { isControlLayerModelConfig } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
@@ -204,7 +204,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
// If the parent is a control layer adapter, we should check if the model has a default filter and set it if so
const selectModelConfig = buildSelectModelConfig(
this.parent.state.controlAdapter.model.key,
isControlNetOrT2IAdapterModelConfig
isControlLayerModelConfig
);
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
// This always returns a filter

View File

@@ -34,7 +34,13 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
ImageDTO,
IPAdapterModelConfig,
T2IAdapterModelConfig,
} from 'services/api/types';
import { assert } from 'tsafe';
import type {
@@ -486,7 +492,7 @@ export const canvasSlice = createSlice({
) => {
const { entityIdentifier, weight } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer || !layer.controlAdapter) {
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
return;
}
layer.controlAdapter.weight = weight;
@@ -497,7 +503,7 @@ export const canvasSlice = createSlice({
) => {
const { entityIdentifier, beginEndStepPct } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer || !layer.controlAdapter) {
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;

View File

@@ -454,7 +454,9 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
* Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default
* filter for the model type will be used.
*/
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
export const getFilterForModel = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null
) => {
if (!modelConfig) {
// No model
return null;

View File

@@ -1,9 +1,9 @@
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const useControlNetOrT2IAdapterDefaultSettings = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
export const useControlAdapterModelDefaultSettings = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig
) => {
const defaultSettingsDefaults = useMemo(() => {
return {

View File

@@ -1,6 +1,6 @@
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
import { useControlAdapterModelDefaultSettings } from 'features/modelManagerV2/hooks/useControlAdapterModelDefaultSettings';
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/DefaultPreprocessor';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
@@ -9,28 +9,28 @@ import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { PiCheckBold } from 'react-icons/pi';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
export type ControlAdapterModelDefaultSettingsFormData = {
preprocessor: FormField<string>;
};
type Props = {
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig;
};
export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Props) => {
export const ControlAdapterModelDefaultSettings = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
const defaultSettingsDefaults = useControlAdapterModelDefaultSettings(modelConfig);
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
const { handleSubmit, control, formState, reset } = useForm<ControlNetOrT2IAdapterDefaultSettingsFormData>({
const { handleSubmit, control, formState, reset } = useForm<ControlAdapterModelDefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
const onSubmit = useCallback<SubmitHandler<ControlAdapterModelDefaultSettingsFormData>>(
(data) => {
const body = {
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
@@ -85,4 +85,4 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
);
});
ControlNetOrT2IAdapterDefaultSettings.displayName = 'ControlNetOrT2IAdapterDefaultSettings';
ControlAdapterModelDefaultSettings.displayName = 'ControlAdapterModelDefaultSettings';

View File

@@ -1,7 +1,7 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
import type { ControlAdapterModelDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
@@ -26,9 +26,9 @@ const OPTIONS = [
{ label: 'None', value: 'none' },
] as const;
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
type DefaultSchedulerType = ControlAdapterModelDefaultSettingsFormData['preprocessor'];
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) => {
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlAdapterModelDefaultSettingsFormData>) => {
const { t } = useTranslation();
const { field } = useController(props);

View File

@@ -1,5 +1,5 @@
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
@@ -21,7 +21,11 @@ export const ModelView = memo(({ modelConfig }: Props) => {
if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
return true;
}
if (modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') {
if (
modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora'
) {
return true;
}
if (modelConfig.type === 'main' || modelConfig.type === 'lora') {
@@ -69,9 +73,9 @@ export const ModelView = memo(({ modelConfig }: Props) => {
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
<MainModelDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
)}
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && (
<TriggerPhrases modelConfig={modelConfig} />
)}

View File

@@ -11,9 +11,9 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isControlLayerModelConfig,
isControlLoRAModelConfig,
isControlNetModelConfig,
isControlLayerModelConfig,
isFluxMainModelModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,

View File

@@ -145,8 +145,10 @@ export const isControlNetModelConfig = (config: AnyModelConfig): config is Contr
return config.type === 'controlnet';
};
export const isControlLayerModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => {
return config.type === 'controlnet' || config.type === "t2i_adapter" || config.type === "control_lora";
export const isControlLayerModelConfig = (
config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => {
return config.type === 'controlnet' || config.type === 't2i_adapter' || config.type === 'control_lora';
};
export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdapterModelConfig => {
@@ -207,12 +209,6 @@ export const isSpandrelImageToImageModelConfig = (
return config.type === 'spandrel_image_to_image';
};
export const isControlNetOrT2IAdapterModelConfig = (
config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig => {
return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config);
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};