feat(ui): rework control adapter/ip adapter creation handling

- Add selectors to get the default control adapter and ip adapter with model, preferring controlnet over t2i adapter for model
- Add hooks to add each entity type, using the defaults
- Add hooks to add prompts/ip adapters to a regional guidance layer
- Use the defaults in other places where we add control layers or ip adapters (e.g. dnd-triggered entity creation)
This commit is contained in:
psychedelicious
2024-09-07 19:14:11 +10:00
parent 864e471e5a
commit 464603e0ea
12 changed files with 253 additions and 237 deletions

View File

@@ -1,6 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import {
controlLayerAdded,
ipaImageChanged,
@@ -103,11 +104,14 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const state = getState();
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const { x, y } = selectCanvasSlice(state).bbox.rect;
const defaultControlAdapter = selectDefaultControlAdapter(state);
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
controlAdapter: defaultControlAdapter,
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
return;

View File

@@ -1,34 +1,22 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
controlLayerAdded,
inpaintMaskAdded,
ipaAdded,
rasterLayerAdded,
rgAdded,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
useAddControlLayer,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
} from 'features/controlLayers/hooks/addLayerHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
export const CanvasAddEntityButtons = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
const addControlLayer = useCallback(() => {
dispatch(controlLayerAdded({ isSelected: true }));
}, [dispatch]);
const addIPAdapter = useCallback(() => {
dispatch(ipaAdded({ isSelected: true }));
}, [dispatch]);
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addIPAdapter = useAddIPAdapter();
return (
<Flex flexDir="column" w="full" h="full" alignItems="center">

View File

@@ -1,37 +1,22 @@
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useDefaultIPAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import {
controlLayerAdded,
inpaintMaskAdded,
ipaAdded,
rasterLayerAdded,
rgAdded,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
useAddControlLayer,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
} from 'features/controlLayers/hooks/addLayerHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const defaultIPAdapter = useDefaultIPAdapter();
const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
const addControlLayer = useCallback(() => {
dispatch(controlLayerAdded({ isSelected: true }));
}, [dispatch]);
const addIPAdapter = useCallback(() => {
const overrides = { ipAdapter: defaultIPAdapter };
dispatch(ipaAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addIPAdapter = useAddIPAdapter();
return (
<Menu>

View File

@@ -1,21 +1,35 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useControlLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import {
controlLayerBeginEndStepPctChanged,
controlLayerControlModeChanged,
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import type { ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
() =>
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
const layer = selectEntityOrThrow(canvas, entityIdentifier);
return layer.controlAdapter;
}),
[entityIdentifier]
);
const controlAdapter = useAppSelector(selectControlAdapter);
return controlAdapter;
};
export const ControlLayerControlAdapter = memo(() => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('control_layer');

View File

@@ -1,10 +1,7 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { IMAGE_FILTERS, isFilterableEntityIdentifier, isFilterType } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
@@ -17,8 +14,6 @@ type Props = {
export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const canvasManager = useCanvasManager();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
@@ -29,34 +24,8 @@ export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onCha
return;
}
onChangeModel(modelConfig);
// When we set the model for the first time, we'll set the default filter settings and open the filter popup
if (modelKey) {
// If there is already a model key, this is not the first time we're setting the model
return;
}
// Open the filter popup by setting this entity as the filtering entity
if (!canvasManager.stateApi.$isFiltering.get()) {
if (!isFilterableEntityIdentifier(entityIdentifier)) {
return;
}
const adapter = canvasManager.getAdapter(entityIdentifier);
if (!adapter) {
return;
}
// Update the filter, preferring the model's default
const filterConfig = isFilterType(modelConfig.default_settings?.preprocessor)
? IMAGE_FILTERS[modelConfig.default_settings.preprocessor].buildDefaults(modelConfig.base)
: IMAGE_FILTERS.canny_image_processor.buildDefaults(modelConfig.base);
adapter.filterer.startFilter(filterConfig);
adapter.filterer.previewFilter();
}
},
[canvasManager, entityIdentifier, modelKey, onChangeModel]
[onChangeModel]
);
const getIsDisabled = useCallback(

View File

@@ -1,42 +1,28 @@
import { Button, Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import {
rgIPAdapterAdded,
rgNegativePromptChanged,
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import { useCallback, useMemo } from 'react';
buildSelectValidRegionalGuidanceActions,
useAddRegionalGuidanceIPAdapter,
useAddRegionalGuidanceNegativePrompt,
useAddRegionalGuidancePositivePrompt,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const addRegionalGuidanceIPAdapter = useAddRegionalGuidanceIPAdapter(entityIdentifier);
const addRegionalGuidancePositivePrompt = useAddRegionalGuidancePositivePrompt(entityIdentifier);
const addRegionalGuidanceNegativePrompt = useAddRegionalGuidanceNegativePrompt(entityIdentifier);
const selectValidActions = useMemo(
() =>
createMemoizedSelector(selectCanvasSlice, (canvas) => {
const entity = selectEntityOrThrow(canvas, entityIdentifier);
return {
canAddPositivePrompt: entity?.positivePrompt === null,
canAddNegativePrompt: entity?.negativePrompt === null,
};
}),
() => buildSelectValidRegionalGuidanceActions(entityIdentifier),
[entityIdentifier]
);
const validActions = useAppSelector(selectValidActions);
const addPositivePrompt = useCallback(() => {
dispatch(rgPositivePromptChanged({ entityIdentifier, prompt: '' }));
}, [dispatch, entityIdentifier]);
const addNegativePrompt = useCallback(() => {
dispatch(rgNegativePromptChanged({ entityIdentifier, prompt: '' }));
}, [dispatch, entityIdentifier]);
const addIPAdapter = useCallback(() => {
dispatch(rgIPAdapterAdded({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<Flex w="full" p={2} justifyContent="space-between">
@@ -44,7 +30,7 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
size="sm"
variant="ghost"
leftIcon={<PiPlusBold />}
onClick={addPositivePrompt}
onClick={addRegionalGuidancePositivePrompt}
isDisabled={!validActions.canAddPositivePrompt}
>
{t('common.positivePrompt')}
@@ -53,12 +39,12 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
size="sm"
variant="ghost"
leftIcon={<PiPlusBold />}
onClick={addNegativePrompt}
onClick={addRegionalGuidanceNegativePrompt}
isDisabled={!validActions.canAddNegativePrompt}
>
{t('common.negativePrompt')}
</Button>
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addRegionalGuidanceIPAdapter}>
{t('common.ipAdapter')}
</Button>
</Flex>

View File

@@ -1,53 +1,38 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
rgIPAdapterAdded,
rgNegativePromptChanged,
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import { memo, useCallback, useMemo } from 'react';
buildSelectValidRegionalGuidanceActions,
useAddRegionalGuidanceIPAdapter,
useAddRegionalGuidanceNegativePrompt,
useAddRegionalGuidancePositivePrompt,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const RegionalGuidanceMenuItemsAddPromptsAndIPAdapter = memo(() => {
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const addRegionalGuidanceIPAdapter = useAddRegionalGuidanceIPAdapter(entityIdentifier);
const addRegionalGuidancePositivePrompt = useAddRegionalGuidancePositivePrompt(entityIdentifier);
const addRegionalGuidanceNegativePrompt = useAddRegionalGuidanceNegativePrompt(entityIdentifier);
const selectValidActions = useMemo(
() =>
createMemoizedSelector(selectCanvasSlice, (canvas) => {
const entity = selectEntity(canvas, entityIdentifier);
return {
canAddPositivePrompt: entity?.positivePrompt === null,
canAddNegativePrompt: entity?.negativePrompt === null,
};
}),
() => buildSelectValidRegionalGuidanceActions(entityIdentifier),
[entityIdentifier]
);
const validActions = useAppSelector(selectValidActions);
const addPositivePrompt = useCallback(() => {
dispatch(rgPositivePromptChanged({ entityIdentifier, prompt: '' }));
}, [dispatch, entityIdentifier]);
const addNegativePrompt = useCallback(() => {
dispatch(rgNegativePromptChanged({ entityIdentifier, prompt: '' }));
}, [dispatch, entityIdentifier]);
const addIPAdapter = useCallback(() => {
dispatch(rgIPAdapterAdded({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
return (
<>
<MenuItem onClick={addPositivePrompt} isDisabled={!validActions.canAddPositivePrompt || isBusy}>
<MenuItem onClick={addRegionalGuidancePositivePrompt} isDisabled={!validActions.canAddPositivePrompt || isBusy}>
{t('controlLayers.addPositivePrompt')}
</MenuItem>
<MenuItem onClick={addNegativePrompt} isDisabled={!validActions.canAddNegativePrompt || isBusy}>
<MenuItem onClick={addRegionalGuidanceNegativePrompt} isDisabled={!validActions.canAddNegativePrompt || isBusy}>
{t('controlLayers.addNegativePrompt')}
</MenuItem>
<MenuItem onClick={addIPAdapter} isDisabled={isBusy}>
<MenuItem onClick={addRegionalGuidanceIPAdapter} isDisabled={isBusy}>
{t('controlLayers.addIPAdapter')}
</MenuItem>
</>

View File

@@ -1,12 +1,11 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
controlLayerAdded,
inpaintMaskAdded,
ipaAdded,
rasterLayerAdded,
rgAdded,
} from 'features/controlLayers/store/canvasSlice';
useAddControlLayer,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
} from 'features/controlLayers/hooks/addLayerHooks';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -18,26 +17,31 @@ type Props = {
export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addIPAdapter = useAddIPAdapter();
const onClick = useCallback(() => {
switch (type) {
case 'inpaint_mask':
dispatch(inpaintMaskAdded({ isSelected: true }));
addInpaintMask();
break;
case 'regional_guidance':
dispatch(rgAdded({ isSelected: true }));
addRegionalGuidance();
break;
case 'raster_layer':
dispatch(rasterLayerAdded({ isSelected: true }));
addRasterLayer();
break;
case 'control_layer':
dispatch(controlLayerAdded({ isSelected: true }));
addControlLayer();
break;
case 'ip_adapter':
dispatch(ipaAdded({ isSelected: true }));
addIPAdapter();
break;
}
}, [dispatch, type]);
}, [addControlLayer, addIPAdapter, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
const label = useMemo(() => {
switch (type) {

View File

@@ -0,0 +1,154 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import {
controlLayerAdded,
inpaintMaskAdded,
ipaAdded,
rasterLayerAdded,
rgAdded,
rgIPAdapterAdded,
rgNegativePromptChanged,
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
(query, base): ControlNetConfig | T2IAdapterConfig => {
const { data } = query;
let model: ControlNetModelConfig | T2IAdapterModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors
.selectAll(data)
.filter(isControlNetOrT2IAdapterModelConfig)
.sort((a) => (a.type === 'controlnet' ? -1 : 1)); // Prefer ControlNet models
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
}
const controlAdapter = model?.type === 't2i_adapter' ? deepClone(initialT2IAdapter) : deepClone(initialControlNet);
if (model) {
controlAdapter.model = zModelIdentifierField.parse(model);
}
return controlAdapter;
}
);
const selectDefaultIPAdapter = createSelector(selectModelConfigsQuery, selectBase, (query, base): IPAdapterConfig => {
const { data } = query;
let model: IPAdapterModelConfig | null = null;
if (data) {
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
}
const ipAdapter = deepClone(initialIPAdapter);
if (model) {
ipAdapter.model = zModelIdentifierField.parse(model);
}
return ipAdapter;
});
export const useAddControlLayer = () => {
const dispatch = useAppDispatch();
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const addControlLayer = useCallback(() => {
const overrides = { controlAdapter: defaultControlAdapter };
dispatch(controlLayerAdded({ isSelected: true, overrides }));
}, [defaultControlAdapter, dispatch]);
return addControlLayer;
};
export const useAddRasterLayer = () => {
const dispatch = useAppDispatch();
const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
return addRasterLayer;
};
export const useAddInpaintMask = () => {
const dispatch = useAppDispatch();
const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
return addInpaintMask;
};
export const useAddRegionalGuidance = () => {
const dispatch = useAppDispatch();
const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
return addRegionalGuidance;
};
export const useAddIPAdapter = () => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const addControlLayer = useCallback(() => {
const overrides = { ipAdapter: defaultIPAdapter };
dispatch(ipaAdded({ isSelected: true, overrides }));
}, [defaultIPAdapter, dispatch]);
return addControlLayer;
};
export const useAddRegionalGuidanceIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>) => {
const dispatch = useAppDispatch();
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
const addRegionalGuidanceIPAdapter = useCallback(() => {
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: defaultIPAdapter }));
}, [defaultIPAdapter, dispatch, entityIdentifier]);
return addRegionalGuidanceIPAdapter;
};
export const useAddRegionalGuidancePositivePrompt = (entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>) => {
const dispatch = useAppDispatch();
const addRegionalGuidancePositivePrompt = useCallback(() => {
dispatch(rgPositivePromptChanged({ entityIdentifier, prompt: '' }));
}, [dispatch, entityIdentifier]);
return addRegionalGuidancePositivePrompt;
};
export const useAddRegionalGuidanceNegativePrompt = (entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>) => {
const dispatch = useAppDispatch();
const addRegionalGuidanceNegativePrompt = useCallback(() => {
dispatch(rgNegativePromptChanged({ entityIdentifier, prompt: '' }));
}, [dispatch, entityIdentifier]);
return addRegionalGuidanceNegativePrompt;
};
export const buildSelectValidRegionalGuidanceActions = (
entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>
) => {
return createMemoizedSelector(selectCanvasSlice, (canvas) => {
const entity = selectEntityOrThrow(canvas, entityIdentifier);
return {
canAddPositivePrompt: entity?.positivePrompt === null,
canAddNegativePrompt: entity?.negativePrompt === null,
};
});
};

View File

@@ -1,69 +0,0 @@
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useMemo } from 'react';
import { useControlNetAndT2IAdapterModels, useIPAdapterModels } from 'services/api/hooks/modelsByType';
export const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
const selectControlAdapter = useMemo(
() =>
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
const layer = selectEntityOrThrow(canvas, entityIdentifier);
return layer.controlAdapter;
}),
[entityIdentifier]
);
const controlAdapter = useAppSelector(selectControlAdapter);
return controlAdapter;
};
/** @knipignore */
export const useDefaultControlAdapter = (): ControlNetConfig | T2IAdapterConfig => {
const [modelConfigs] = useControlNetAndT2IAdapterModels();
const baseModel = useAppSelector(selectBase);
const defaultControlAdapter = useMemo(() => {
const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true));
const model = compatibleModels[0] ?? modelConfigs[0] ?? null;
const controlAdapter = model?.type === 't2i_adapter' ? deepClone(initialT2IAdapter) : deepClone(initialControlNet);
if (model) {
controlAdapter.model = zModelIdentifierField.parse(model);
}
return controlAdapter;
}, [baseModel, modelConfigs]);
return defaultControlAdapter;
};
export const useDefaultIPAdapter = (): IPAdapterConfig => {
const [modelConfigs] = useIPAdapterModels();
const baseModel = useAppSelector(selectBase);
const defaultControlAdapter = useMemo(() => {
const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true));
const model = compatibleModels[0] ?? modelConfigs[0] ?? null;
const ipAdapter = deepClone(initialIPAdapter);
if (model) {
ipAdapter.model = zModelIdentifierField.parse(model);
}
return ipAdapter;
}, [baseModel, modelConfigs]);
return defaultControlAdapter;
};

View File

@@ -553,14 +553,8 @@ export type FillStyle = z.infer<typeof zFillStyle>;
export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success;
const zFill = z.object({ style: zFillStyle, color: zRgbColor });
const zRegionalGuidanceIPAdapterConfig = z.object({
const zRegionalGuidanceIPAdapterConfig = zIPAdapterConfig.extend({
id: zId,
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
clipVisionModel: zCLIPVisionModelV2,
});
export type RegionalGuidanceIPAdapterConfig = z.infer<typeof zRegionalGuidanceIPAdapterConfig>;

View File

@@ -278,3 +278,5 @@ export const {
usePruneCompletedModelInstallsMutation,
useGetStarterModelsQuery,
} = modelsApi;
export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();