mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
(ui): replace logic for controlnet/t2i to include control_loras and display default settings in model manager
This commit is contained in:
@@ -30,7 +30,7 @@ import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
@@ -190,7 +190,7 @@ const handleLoRAModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
};
|
||||
|
||||
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
|
||||
const caModels = models.filter(isControlLayerModelConfig);
|
||||
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
|
||||
const selectedControlAdapterModel = entity.controlAdapter.model;
|
||||
// `null` is a valid control adapter model - no need to do anything.
|
||||
|
||||
@@ -26,7 +26,12 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
ImageDTO,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) =>
|
||||
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
|
||||
@@ -157,8 +162,10 @@ export const ControlLayerControlAdapter = memo(() => {
|
||||
/>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
{controlAdapter.type !== 'control_lora' && <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />}
|
||||
{controlAdapter.type !== 'control_lora' && (
|
||||
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
)}
|
||||
{controlAdapter.type === 'controlnet' && !isFLUX && (
|
||||
<ControlLayerControlAdapterControlMode
|
||||
controlMode={controlAdapter.controlMode}
|
||||
|
||||
@@ -5,7 +5,12 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useControlLayerModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelKey: string | null;
|
||||
|
||||
@@ -18,6 +18,7 @@ import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/se
|
||||
import type {
|
||||
CanvasEntityIdentifier,
|
||||
CanvasRegionalGuidanceState,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
@@ -26,8 +27,13 @@ import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { useCallback } from 'react';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { isControlLayerModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Selects the default control adapter configuration based on the model configurations and the base.
|
||||
@@ -39,13 +45,13 @@ import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'ser
|
||||
export const selectDefaultControlAdapter = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
(query, base): ControlNetConfig | T2IAdapterConfig => {
|
||||
(query, base): ControlNetConfig | T2IAdapterConfig | ControlLoRAConfig => {
|
||||
const { data } = query;
|
||||
let model: ControlNetModelConfig | T2IAdapterModelConfig | null = null;
|
||||
let model: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null = null;
|
||||
if (data) {
|
||||
const modelConfigs = modelConfigsAdapterSelectors
|
||||
.selectAll(data)
|
||||
.filter(isControlNetOrT2IAdapterModelConfig)
|
||||
.filter(isControlLayerModelConfig)
|
||||
.sort((a) => (a.type === 'controlnet' ? -1 : 1)); // Prefer ControlNet models
|
||||
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
|
||||
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
|
||||
|
||||
@@ -18,7 +18,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { buildSelectModelConfig } from 'services/api/hooks/modelsByType';
|
||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
||||
import { isControlLayerModelConfig } from 'services/api/types';
|
||||
import stableHash from 'stable-hash';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -204,7 +204,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
|
||||
// If the parent is a control layer adapter, we should check if the model has a default filter and set it if so
|
||||
const selectModelConfig = buildSelectModelConfig(
|
||||
this.parent.state.controlAdapter.model.key,
|
||||
isControlNetOrT2IAdapterModelConfig
|
||||
isControlLayerModelConfig
|
||||
);
|
||||
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
|
||||
// This always returns a filter
|
||||
|
||||
@@ -34,7 +34,13 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { merge } from 'lodash-es';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type {
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
ImageDTO,
|
||||
IPAdapterModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type {
|
||||
@@ -486,7 +492,7 @@ export const canvasSlice = createSlice({
|
||||
) => {
|
||||
const { entityIdentifier, weight } = action.payload;
|
||||
const layer = selectEntity(state, entityIdentifier);
|
||||
if (!layer || !layer.controlAdapter) {
|
||||
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
|
||||
return;
|
||||
}
|
||||
layer.controlAdapter.weight = weight;
|
||||
@@ -497,7 +503,7 @@ export const canvasSlice = createSlice({
|
||||
) => {
|
||||
const { entityIdentifier, beginEndStepPct } = action.payload;
|
||||
const layer = selectEntity(state, entityIdentifier);
|
||||
if (!layer || !layer.controlAdapter) {
|
||||
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
|
||||
return;
|
||||
}
|
||||
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
|
||||
|
||||
@@ -454,7 +454,9 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
|
||||
* Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default
|
||||
* filter for the model type will be used.
|
||||
*/
|
||||
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
|
||||
export const getFilterForModel = (
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null
|
||||
) => {
|
||||
if (!modelConfig) {
|
||||
// No model
|
||||
return null;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
|
||||
export const useControlAdapterModelDefaultSettings = (
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig
|
||||
) => {
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||
import { useControlAdapterModelDefaultSettings } from 'features/modelManagerV2/hooks/useControlAdapterModelDefaultSettings';
|
||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/DefaultPreprocessor';
|
||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
@@ -9,28 +9,28 @@ import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
||||
export type ControlAdapterModelDefaultSettingsFormData = {
|
||||
preprocessor: FormField<string>;
|
||||
};
|
||||
|
||||
type Props = {
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig;
|
||||
};
|
||||
|
||||
export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
export const ControlAdapterModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
|
||||
const defaultSettingsDefaults = useControlAdapterModelDefaultSettings(modelConfig);
|
||||
|
||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||
|
||||
const { handleSubmit, control, formState, reset } = useForm<ControlNetOrT2IAdapterDefaultSettingsFormData>({
|
||||
const { handleSubmit, control, formState, reset } = useForm<ControlAdapterModelDefaultSettingsFormData>({
|
||||
defaultValues: defaultSettingsDefaults,
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
||||
const onSubmit = useCallback<SubmitHandler<ControlAdapterModelDefaultSettingsFormData>>(
|
||||
(data) => {
|
||||
const body = {
|
||||
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
||||
@@ -85,4 +85,4 @@ export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Prop
|
||||
);
|
||||
});
|
||||
|
||||
ControlNetOrT2IAdapterDefaultSettings.displayName = 'ControlNetOrT2IAdapterDefaultSettings';
|
||||
ControlAdapterModelDefaultSettings.displayName = 'ControlAdapterModelDefaultSettings';
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
import type { ControlAdapterModelDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
|
||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -26,9 +26,9 @@ const OPTIONS = [
|
||||
{ label: 'None', value: 'none' },
|
||||
] as const;
|
||||
|
||||
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
|
||||
type DefaultSchedulerType = ControlAdapterModelDefaultSettingsFormData['preprocessor'];
|
||||
|
||||
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) => {
|
||||
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlAdapterModelDefaultSettingsFormData>) => {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
|
||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||
@@ -21,7 +21,11 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
|
||||
return true;
|
||||
}
|
||||
if (modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') {
|
||||
if (
|
||||
modelConfig.type === 'controlnet' ||
|
||||
modelConfig.type === 't2i_adapter' ||
|
||||
modelConfig.type === 'control_lora'
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
if (modelConfig.type === 'main' || modelConfig.type === 'lora') {
|
||||
@@ -69,9 +73,9 @@ export const ModelView = memo(({ modelConfig }: Props) => {
|
||||
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||
)}
|
||||
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
|
||||
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
|
||||
)}
|
||||
{(modelConfig.type === 'controlnet' ||
|
||||
modelConfig.type === 't2i_adapter' ||
|
||||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
|
||||
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && (
|
||||
<TriggerPhrases modelConfig={modelConfig} />
|
||||
)}
|
||||
|
||||
@@ -11,9 +11,9 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isCLIPVisionModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isControlLoRAModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
|
||||
@@ -145,8 +145,10 @@ export const isControlNetModelConfig = (config: AnyModelConfig): config is Contr
|
||||
return config.type === 'controlnet';
|
||||
};
|
||||
|
||||
export const isControlLayerModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => {
|
||||
return config.type === 'controlnet' || config.type === "t2i_adapter" || config.type === "control_lora";
|
||||
export const isControlLayerModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => {
|
||||
return config.type === 'controlnet' || config.type === 't2i_adapter' || config.type === 'control_lora';
|
||||
};
|
||||
|
||||
export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdapterModelConfig => {
|
||||
@@ -207,12 +209,6 @@ export const isSpandrelImageToImageModelConfig = (
|
||||
return config.type === 'spandrel_image_to_image';
|
||||
};
|
||||
|
||||
export const isControlNetOrT2IAdapterModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is ControlNetModelConfig | T2IAdapterModelConfig => {
|
||||
return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config);
|
||||
};
|
||||
|
||||
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base !== 'sdxl-refiner';
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user