mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-17 14:54:04 -05:00
refactor(ui): wire up CA logic across (wip)
This commit is contained in:
committed by
Kent Keirsey
parent
424a27eeda
commit
0e55488ff6
@@ -1,75 +0,0 @@
|
||||
import { Box, Flex, Icon, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ControlAdapterProcessorComponent from 'features/controlAdapters/components/ControlAdapterProcessorComponent';
|
||||
import ControlAdapterShouldAutoConfig from 'features/controlAdapters/components/ControlAdapterShouldAutoConfig';
|
||||
import ParamControlAdapterIPMethod from 'features/controlAdapters/components/parameters/ParamControlAdapterIPMethod';
|
||||
import ParamControlAdapterProcessorSelect from 'features/controlAdapters/components/parameters/ParamControlAdapterProcessorSelect';
|
||||
import { ParamControlAdapterBeginEnd } from 'features/controlLayers/components/CALayer/CALayerBeginEndStepPct';
|
||||
import ParamControlAdapterControlMode from 'features/controlLayers/components/CALayer/CALayerControlMode';
|
||||
import { CALayerImagePreview } from 'features/controlLayers/components/CALayer/CALayerImagePreview';
|
||||
import ParamControlAdapterModel from 'features/controlLayers/components/CALayer/CALayerModelCombobox';
|
||||
import ParamControlAdapterWeight from 'features/controlLayers/components/CALayer/CALayerWeight';
|
||||
import { selectCALayer } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretUpBold } from 'react-icons/pi';
|
||||
import { useToggle } from 'react-use';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const CALayerCAConfig = memo(({ layerId }: Props) => {
|
||||
const caType = useAppSelector((s) => selectCALayer(s.controlLayers.present, layerId).controlAdapter.type);
|
||||
const { t } = useTranslation();
|
||||
const [isExpanded, toggleIsExpanded] = useToggle(false);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4} position="relative" w="full">
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
|
||||
<ParamControlAdapterModel id={id} />{' '}
|
||||
</Box>
|
||||
|
||||
{controlAdapterType !== 'ip_adapter' && (
|
||||
<IconButton
|
||||
size="sm"
|
||||
tooltip={isExpanded ? t('controlnet.hideAdvanced') : t('controlnet.showAdvanced')}
|
||||
aria-label={isExpanded ? t('controlnet.hideAdvanced') : t('controlnet.showAdvanced')}
|
||||
onClick={toggleIsExpanded}
|
||||
variant="ghost"
|
||||
icon={
|
||||
<Icon
|
||||
boxSize={4}
|
||||
as={PiCaretUpBold}
|
||||
transform={isExpanded ? 'rotate(0deg)' : 'rotate(180deg)'}
|
||||
transitionProperty="common"
|
||||
transitionDuration="normal"
|
||||
/>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex gap={4} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={3} w="full">
|
||||
{controlAdapterType === 'ip_adapter' && <ParamControlAdapterIPMethod id={id} />}
|
||||
{controlAdapterType === 'controlnet' && <ParamControlAdapterControlMode id={id} />}
|
||||
<ParamControlAdapterWeight id={id} />
|
||||
<ParamControlAdapterBeginEnd id={id} />
|
||||
</Flex>
|
||||
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
|
||||
<CALayerImagePreview id={id} isSmall />
|
||||
</Flex>
|
||||
</Flex>
|
||||
{isExpanded && (
|
||||
<>
|
||||
<ControlAdapterShouldAutoConfig id={id} />
|
||||
<ParamControlAdapterProcessorSelect id={id} />
|
||||
<ControlAdapterProcessorComponent id={id} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CALayerCAConfig.displayName = 'CALayerCAConfig';
|
||||
@@ -1,29 +1,15 @@
|
||||
import { Flex, Spacer, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ControlAdapterLayerConfig from 'features/controlLayers/components/CALayer/ControlAdapterLayerConfig';
|
||||
import { IPALayerConfig } from 'features/controlLayers/components/IPALayer/IPALayerConfig';
|
||||
import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton';
|
||||
import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle';
|
||||
import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle';
|
||||
import { isIPAdapterLayer, selectControlLayersSlice } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const IPALayer = memo(({ layerId }: Props) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectControlLayersSlice, (controlLayers) => {
|
||||
const layer = controlLayers.present.layers.find((l) => l.id === layerId);
|
||||
assert(isIPAdapterLayer(layer), `Layer ${layerId} not found or not an IP Adapter layer`);
|
||||
return layer.ipAdapterId;
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
const ipAdapterId = useAppSelector(selector);
|
||||
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
|
||||
return (
|
||||
<Flex gap={2} bg="base.800" borderRadius="base" p="1px" px={2}>
|
||||
@@ -36,7 +22,7 @@ export const IPALayer = memo(({ layerId }: Props) => {
|
||||
</Flex>
|
||||
{isOpen && (
|
||||
<Flex flexDir="column" gap={3} px={3} pb={3}>
|
||||
<ControlAdapterLayerConfig id={ipAdapterId} />
|
||||
<IPALayerConfig layerId={layerId} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ControlAdapterBeginEndStepPct } from 'features/controlLayers/components/CALayer/ControlAdapterBeginEndStepPct';
|
||||
import { ControlAdapterWeight } from 'features/controlLayers/components/CALayer/ControlAdapterWeight';
|
||||
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPALayer/IPAdapterImagePreview';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/IPALayer/IPAdapterMethod';
|
||||
import { IPAdapterModelCombobox } from 'features/controlLayers/components/IPALayer/IPALayerModelCombobox';
|
||||
import {
|
||||
caOrIPALayerBeginEndStepPctChanged,
|
||||
caOrIPALayerWeightChanged,
|
||||
ipaLayerCLIPVisionModelChanged,
|
||||
ipaLayerImageChanged,
|
||||
ipaLayerMethodChanged,
|
||||
ipaLayerModelChanged,
|
||||
selectIPALayer,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { CLIPVisionModel, IPMethod } from 'features/controlLayers/util/controlAdapters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
layerId: string;
|
||||
};
|
||||
|
||||
export const IPALayerConfig = memo(({ layerId }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const ipAdapter = useAppSelector((s) => selectIPALayer(s.controlLayers.present, layerId).ipAdapter);
|
||||
|
||||
const onChangeBeginEndStepPct = useCallback(
|
||||
(beginEndStepPct: [number, number]) => {
|
||||
dispatch(
|
||||
caOrIPALayerBeginEndStepPctChanged({
|
||||
layerId,
|
||||
beginEndStepPct,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
const onChangeWeight = useCallback(
|
||||
(weight: number) => {
|
||||
dispatch(caOrIPALayerWeightChanged({ layerId, weight }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
const onChangeIPMethod = useCallback(
|
||||
(method: IPMethod) => {
|
||||
dispatch(ipaLayerMethodChanged({ layerId, method }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig) => {
|
||||
dispatch(ipaLayerModelChanged({ layerId, modelConfig }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
const onChangeCLIPVisionModel = useCallback(
|
||||
(clipVisionModel: CLIPVisionModel) => {
|
||||
dispatch(ipaLayerCLIPVisionModelChanged({ layerId, clipVisionModel }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
const onChangeImage = useCallback(
|
||||
(imageDTO: ImageDTO | null) => {
|
||||
dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
|
||||
},
|
||||
[dispatch, layerId]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4} position="relative" w="full">
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
|
||||
<IPAdapterModelCombobox
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
clipVisionModel={ipAdapter.clipVisionModel}
|
||||
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
<Flex gap={4} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={3} w="full">
|
||||
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
|
||||
<ControlAdapterWeight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<ControlAdapterBeginEndStepPct
|
||||
beginEndStepPct={ipAdapter.beginEndStepPct}
|
||||
onChange={onChangeBeginEndStepPct}
|
||||
/>
|
||||
</Flex>
|
||||
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
|
||||
<IPAdapterImagePreview image={ipAdapter.image} onChangeImage={onChangeImage} layerId={layerId} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
IPALayerConfig.displayName = 'IPALayerConfig';
|
||||
@@ -0,0 +1,100 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import type { CLIPVisionModel } from 'features/controlLayers/util/controlAdapters';
|
||||
import { isCLIPVisionModel } from 'features/controlLayers/util/controlAdapters';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const CLIP_VISION_OPTIONS = [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
];
|
||||
|
||||
type Props = {
|
||||
modelKey: string | null;
|
||||
onChangeModel: (modelConfig: IPAdapterModelConfig) => void;
|
||||
clipVisionModel: CLIPVisionModel;
|
||||
onChangeCLIPVisionModel: (clipVisionModel: CLIPVisionModel) => void;
|
||||
};
|
||||
|
||||
export const IPAdapterModelCombobox = memo(
|
||||
({ modelKey, onChangeModel, clipVisionModel, onChangeCLIPVisionModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
const _onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
}
|
||||
onChangeModel(modelConfig);
|
||||
},
|
||||
[onChangeModel]
|
||||
);
|
||||
|
||||
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
assert(isCLIPVisionModel(v?.value));
|
||||
onChangeCLIPVisionModel(v.value);
|
||||
},
|
||||
[onChangeCLIPVisionModel]
|
||||
);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChangeModel,
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const clipVisionModelValue = useMemo(
|
||||
() => CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel),
|
||||
[clipVisionModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={4}>
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('controlnet.selectModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
{selectedModel?.format === 'checkpoint' && (
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
|
||||
<Combobox
|
||||
options={CLIP_VISION_OPTIONS}
|
||||
placeholder={t('controlnet.selectCLIPVisionModel')}
|
||||
value={clipVisionModelValue}
|
||||
onChange={_onChangeCLIPVisionModel}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IPAdapterModelCombobox.displayName = 'IPALayerModelCombobox';
|
||||
@@ -0,0 +1,119 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { ImageWithDims } from 'features/controlLayers/util/controlAdapters';
|
||||
import type { ControlLayerDropData, ImageDraggableData } from 'features/dnd/types';
|
||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold, PiRulerBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { ControlLayerAction, ImageDTO } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
image: ImageWithDims | null;
|
||||
onChangeImage: (imageDTO: ImageDTO | null) => void;
|
||||
layerId: string; // required for the dnd/upload interactions
|
||||
};
|
||||
|
||||
export const IPAdapterImagePreview = memo(({ image, onChangeImage, layerId }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isConnected = useAppSelector((s) => s.system.isConnected);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const shift = useShiftModifier();
|
||||
|
||||
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
|
||||
image?.imageName ?? skipToken
|
||||
);
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
onChangeImage(null);
|
||||
}, [onChangeImage]);
|
||||
|
||||
const handleSetControlImageToDimensions = useCallback(() => {
|
||||
if (!controlImage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setBoundingBoxDimensions({ width: controlImage.width, height: controlImage.height }, optimalDimension));
|
||||
} else {
|
||||
if (shift) {
|
||||
const { width, height } = controlImage;
|
||||
dispatch(widthChanged({ width, updateAspectRatio: true }));
|
||||
dispatch(heightChanged({ height, updateAspectRatio: true }));
|
||||
} else {
|
||||
const { width, height } = calculateNewSize(
|
||||
controlImage.width / controlImage.height,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
dispatch(widthChanged({ width, updateAspectRatio: true }));
|
||||
dispatch(heightChanged({ height, updateAspectRatio: true }));
|
||||
}
|
||||
}
|
||||
}, [controlImage, activeTabName, dispatch, optimalDimension, shift]);
|
||||
|
||||
const draggableData = useMemo<ImageDraggableData | undefined>(() => {
|
||||
if (controlImage) {
|
||||
return {
|
||||
id: layerId,
|
||||
payloadType: 'IMAGE_DTO',
|
||||
payload: { imageDTO: controlImage },
|
||||
};
|
||||
}
|
||||
}, [controlImage, layerId]);
|
||||
|
||||
const droppableData = useMemo<ControlLayerDropData>(
|
||||
() => ({
|
||||
id: layerId,
|
||||
actionType: 'SET_CONTROL_LAYER_IMAGE',
|
||||
context: { layerId },
|
||||
}),
|
||||
[layerId]
|
||||
);
|
||||
|
||||
const postUploadAction = useMemo<ControlLayerAction>(() => ({ type: 'SET_CONTROL_LAYER_IMAGE', layerId }), [layerId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isConnected && isErrorControlImage) {
|
||||
handleResetControlImage();
|
||||
}
|
||||
}, [handleResetControlImage, isConnected, isErrorControlImage]);
|
||||
|
||||
return (
|
||||
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
postUploadAction={postUploadAction}
|
||||
/>
|
||||
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.resetControlImage')}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
|
||||
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
@@ -0,0 +1,44 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import type { IPMethod } from 'features/controlLayers/util/controlAdapters';
|
||||
import { isIPMethod } from 'features/controlLayers/util/controlAdapters';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
method: IPMethod;
|
||||
onChange: (method: IPMethod) => void;
|
||||
};
|
||||
|
||||
export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const options: { label: string; value: IPMethod }[] = useMemo(
|
||||
() => [
|
||||
{ label: t('controlnet.full'), value: 'full' },
|
||||
{ label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
|
||||
{ label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
assert(isIPMethod(v?.value));
|
||||
onChange(v.value);
|
||||
},
|
||||
[onChange]
|
||||
);
|
||||
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<InformationalPopover feature="ipAdapterMethod">
|
||||
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<Combobox value={value} options={options} onChange={_onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
IPAdapterMethod.displayName = 'IPAdapterMethod';
|
||||
@@ -1,136 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import {
|
||||
controlAdapterCLIPVisionModelChanged,
|
||||
controlAdapterModelChanged,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
ControlNetModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
type ParamControlAdapterModelProps = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||
|
||||
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const isEnabled = useControlAdapterIsEnabled(id);
|
||||
const controlAdapterType = useControlAdapterType(id);
|
||||
const { modelConfig } = useControlAdapterModel(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
|
||||
const mainModel = useAppSelector(selectMainModel);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
controlAdapterModelChanged({
|
||||
id,
|
||||
modelConfig,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, id]
|
||||
);
|
||||
|
||||
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v?.value) {
|
||||
return;
|
||||
}
|
||||
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
|
||||
},
|
||||
[dispatch, id]
|
||||
);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
|
||||
[controlAdapterType, modelConfig]
|
||||
);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const clipVisionOptions = useMemo<ComboboxOption[]>(
|
||||
() => [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
],
|
||||
[]
|
||||
);
|
||||
|
||||
const clipVisionModel = useMemo(
|
||||
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
|
||||
[clipVisionOptions, currentCLIPVisionModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={4}>
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base} w="full">
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('controlnet.selectModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
|
||||
<FormControl
|
||||
isDisabled={!isEnabled}
|
||||
isInvalid={!value || mainModel?.base !== modelConfig?.base}
|
||||
width="max-content"
|
||||
minWidth={28}
|
||||
>
|
||||
<Combobox
|
||||
options={clipVisionOptions}
|
||||
placeholder={t('controlnet.selectCLIPVisionModel')}
|
||||
value={clipVisionModel}
|
||||
onChange={onCLIPVisionModelChange}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlAdapterModel);
|
||||
Reference in New Issue
Block a user