mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support for FLUX Redux in canvas
User facing: When a FLUX main model is selected, users may now add Regional Reference Image layers. When switching between FLUX Redux and FLUX IP Adapter, the settings will change to match the model type. (IP Adapter has weight, begin/end step, but Redux does not.) The image will be retained when switching between the two. Otherwise it works the same way as IP Adapter - both in Global and Regional Reference Image layers. --- Internal state handling: Slightly awkward, but it was easiest to make FLUX Redux a second type of IP Adapter in redux state. Global and regional reference images still have a single `ipAdapter` field, but it can have a type of `ip_adapter` or `flux_redux`. Ideally, this field is called `config` or `settings` or something, but we are past that point. We _could_ do a migration to rename it, but I don't think it's worth the effort. --- Other changes: - Updated canvas layer validators to handle FLUX Redux. - Updated model list loading logic to un-set FLUX Redux models in Canvas if they are not in the list (e.g. if the user deletes the model in the main app). - Updated graph builders - new `addFLUXRedux` util & updated `addRegions` util. - Updated the `buildModelsHook` util to return a hook that accepts a filter callback. This handles a discrepancy: FLUX IP Adapter does not support regional guidance, but FLUX Redux does. The Regional Guidance settings provide the filter to filter out FLUX IP Adapter models from the combined list of IP Adapter ahd Redux models.
This commit is contained in:
@@ -31,6 +31,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
@@ -77,6 +78,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
handleT5EncoderModels(models, state, dispatch, log);
|
||||
handleCLIPEmbedModels(models, state, dispatch, log);
|
||||
handleFLUXVAEModels(models, state, dispatch, log);
|
||||
handleFLUXReduxModels(models, state, dispatch, log);
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -209,6 +211,10 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
|
||||
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const ipaModels = models.filter(isIPAdapterModelConfig);
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedIPAdapterModel = entity.ipAdapter.model;
|
||||
// `null` is a valid IP adapter model - no need to do anything.
|
||||
if (!selectedIPAdapterModel) {
|
||||
@@ -224,6 +230,10 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
|
||||
if (ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedIPAdapterModel = ipAdapter.model;
|
||||
// `null` is a valid IP adapter model - no need to do anything.
|
||||
if (!selectedIPAdapterModel) {
|
||||
@@ -241,6 +251,49 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
});
|
||||
};
|
||||
|
||||
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
|
||||
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.type !== 'flux_redux') {
|
||||
return;
|
||||
}
|
||||
const selectedFLUXReduxModel = entity.ipAdapter.model;
|
||||
// `null` is a valid FLUX Redux model - no need to do anything.
|
||||
if (!selectedFLUXReduxModel) {
|
||||
return;
|
||||
}
|
||||
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
|
||||
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
|
||||
});
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
|
||||
if (ipAdapter.type !== 'flux_redux') {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedFLUXReduxModel = ipAdapter.model;
|
||||
// `null` is a valid FLUX Redux model - no need to do anything.
|
||||
if (!selectedFLUXReduxModel) {
|
||||
return;
|
||||
}
|
||||
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
}
|
||||
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
|
||||
dispatch(
|
||||
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
|
||||
);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const handlePostProcessingModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedPostProcessingModel = state.upscale.postProcessingModel;
|
||||
const allSpandrelModels = models.filter(isSpandrelImageToImageModelConfig);
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
useAddRegionalGuidance,
|
||||
useAddRegionalReferenceImage,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
@@ -22,7 +22,6 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
const addControlLayer = useAddControlLayer();
|
||||
const addGlobalReferenceImage = useAddGlobalReferenceImage();
|
||||
const addRegionalReferenceImage = useAddRegionalReferenceImage();
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
|
||||
return (
|
||||
@@ -75,7 +74,7 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
justifyContent="flex-start"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addRegionalReferenceImage}
|
||||
isDisabled={isFLUX || isSD3}
|
||||
isDisabled={isSD3}
|
||||
>
|
||||
{t('controlLayers.regionalReferenceImage')}
|
||||
</Button>
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
useAddRegionalReferenceImage,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
@@ -23,7 +23,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
const addRegionalReferenceImage = useAddRegionalReferenceImage();
|
||||
const addRasterLayer = useAddRasterLayer();
|
||||
const addControlLayer = useAddControlLayer();
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
|
||||
return (
|
||||
@@ -52,7 +51,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isSD3}>
|
||||
{t('controlLayers.regionalReferenceImage')}
|
||||
</MenuItem>
|
||||
</MenuGroup>
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
|
||||
const FLUX_CLIP_VISION = 'ViT-L';
|
||||
|
||||
const CLIP_VISION_OPTIONS = [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
|
||||
];
|
||||
|
||||
type Props = {
|
||||
model: CLIPVisionModelV2;
|
||||
onChange: (clipVisionModel: CLIPVisionModelV2) => void;
|
||||
};
|
||||
|
||||
export const CLIPVisionModel = memo(({ model, onChange }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
assert(isCLIPVisionModelV2(v?.value));
|
||||
onChange(v.value);
|
||||
},
|
||||
[onChange]
|
||||
);
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
const clipVisionOptions = useMemo(() => {
|
||||
return CLIP_VISION_OPTIONS.map((option) => ({
|
||||
...option,
|
||||
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
|
||||
}));
|
||||
}, [isFLUX]);
|
||||
|
||||
const clipVisionModelValue = useMemo(() => {
|
||||
return CLIP_VISION_OPTIONS.find((o) => o.value === model);
|
||||
}, [model]);
|
||||
|
||||
return (
|
||||
<FormControl width="max-content" minWidth={28}>
|
||||
<Combobox
|
||||
options={clipVisionOptions}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={clipVisionModelValue}
|
||||
onChange={_onChangeCLIPVisionModel}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CLIPVisionModel.displayName = 'CLIPVisionModel';
|
||||
@@ -1,40 +1,36 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
|
||||
const FLUX_CLIP_VISION = 'ViT-L';
|
||||
|
||||
const CLIP_VISION_OPTIONS = [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
|
||||
];
|
||||
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
isRegionalGuidance: boolean;
|
||||
modelKey: string | null;
|
||||
onChangeModel: (modelConfig: IPAdapterModelConfig) => void;
|
||||
clipVisionModel: CLIPVisionModelV2;
|
||||
onChangeCLIPVisionModel: (clipVisionModel: CLIPVisionModelV2) => void;
|
||||
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
|
||||
};
|
||||
|
||||
export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel, onChangeCLIPVisionModel }: Props) => {
|
||||
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||
const filter = useCallback(
|
||||
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
|
||||
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[isRegionalGuidance]
|
||||
);
|
||||
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
const _onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig | null) => {
|
||||
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
}
|
||||
@@ -43,21 +39,11 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
[onChangeModel]
|
||||
);
|
||||
|
||||
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
assert(isCLIPVisionModelV2(v?.value));
|
||||
onChangeCLIPVisionModel(v.value);
|
||||
},
|
||||
[onChangeCLIPVisionModel]
|
||||
);
|
||||
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
(model: AnyModelConfig): boolean => {
|
||||
const isCompatible = currentBaseModel === model.base;
|
||||
const hasMainModel = Boolean(currentBaseModel);
|
||||
return !hasMainModel || !isCompatible;
|
||||
const hasSameBase = currentBaseModel === model.base;
|
||||
return !hasMainModel || !hasSameBase;
|
||||
},
|
||||
[currentBaseModel]
|
||||
);
|
||||
@@ -70,41 +56,18 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const clipVisionOptions = useMemo(() => {
|
||||
return CLIP_VISION_OPTIONS.map((option) => ({
|
||||
...option,
|
||||
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
|
||||
}));
|
||||
}, [isFLUX]);
|
||||
|
||||
const clipVisionModelValue = useMemo(() => {
|
||||
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
|
||||
}, [clipVisionModel]);
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
{selectedModel?.format === 'checkpoint' && (
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
|
||||
<Combobox
|
||||
options={clipVisionOptions}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={clipVisionModelValue}
|
||||
onChange={_onChangeCLIPVisionModel}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
<Tooltip label={selectedModel?.description}>
|
||||
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('common.placeholderSelectAModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Box, Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import { Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
|
||||
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
|
||||
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
@@ -25,7 +26,7 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
|
||||
import { IPAdapterModel } from './IPAdapterModel';
|
||||
@@ -65,7 +66,7 @@ const IPAdapterSettingsContent = memo(() => {
|
||||
);
|
||||
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig) => {
|
||||
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
|
||||
},
|
||||
[dispatch, entityIdentifier]
|
||||
@@ -98,14 +99,14 @@ const IPAdapterSettingsContent = memo(() => {
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
|
||||
<IPAdapterModel
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
clipVisionModel={ipAdapter.clipVisionModel}
|
||||
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
|
||||
/>
|
||||
</Box>
|
||||
<IPAdapterModel
|
||||
isRegionalGuidance={false}
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
/>
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
|
||||
)}
|
||||
<IconButton
|
||||
onClick={pullBboxIntoIPAdapter}
|
||||
isDisabled={isBusy}
|
||||
@@ -116,12 +117,14 @@ const IPAdapterSettingsContent = memo(() => {
|
||||
/>
|
||||
</Flex>
|
||||
<Flex gap={2} w="full" alignItems="center">
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
|
||||
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
|
||||
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
)}
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
|
||||
<IPAdapterImagePreview
|
||||
image={ipAdapter.image}
|
||||
onChangeImage={onChangeImage}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Box, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
|
||||
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
|
||||
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
|
||||
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
|
||||
@@ -26,7 +27,7 @@ import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type Props = {
|
||||
@@ -73,7 +74,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
);
|
||||
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: IPAdapterModelConfig) => {
|
||||
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
|
||||
dispatch(rgIPAdapterModelChanged({ entityIdentifier, referenceImageId, modelConfig }));
|
||||
},
|
||||
[dispatch, entityIdentifier, referenceImageId]
|
||||
@@ -125,14 +126,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
</Flex>
|
||||
<Flex flexDir="column" gap={2} position="relative" w="full">
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
|
||||
<IPAdapterModel
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
clipVisionModel={ipAdapter.clipVisionModel}
|
||||
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
|
||||
/>
|
||||
</Box>
|
||||
<IPAdapterModel
|
||||
isRegionalGuidance={true}
|
||||
modelKey={ipAdapter.model?.key ?? null}
|
||||
onChangeModel={onChangeModel}
|
||||
/>
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
|
||||
)}
|
||||
<IconButton
|
||||
onClick={pullBboxIntoIPAdapter}
|
||||
isDisabled={isBusy}
|
||||
@@ -143,12 +144,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
|
||||
/>
|
||||
</Flex>
|
||||
<Flex gap={2} w="full">
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
|
||||
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
|
||||
{ipAdapter.type === 'ip_adapter' && (
|
||||
<Flex flexDir="column" gap={2} w="full">
|
||||
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
|
||||
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
|
||||
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
|
||||
</Flex>
|
||||
)}
|
||||
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
|
||||
<IPAdapterImagePreview
|
||||
image={ipAdapter.image}
|
||||
onChangeImage={onChangeImage}
|
||||
|
||||
@@ -38,6 +38,7 @@ import type { UndoableOptions } from 'redux-undo';
|
||||
import type {
|
||||
ControlLoRAModelConfig,
|
||||
ControlNetModelConfig,
|
||||
FLUXReduxModelConfig,
|
||||
ImageDTO,
|
||||
IPAdapterModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
@@ -76,6 +77,7 @@ import {
|
||||
imageDTOToImageWithDims,
|
||||
initialControlLoRA,
|
||||
initialControlNet,
|
||||
initialFLUXRedux,
|
||||
initialIPAdapter,
|
||||
initialT2IAdapter,
|
||||
} from './util';
|
||||
@@ -619,11 +621,16 @@ export const canvasSlice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
if (entity.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
entity.ipAdapter.method = method;
|
||||
},
|
||||
referenceImageIPAdapterModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | null }, 'reference_image'>>
|
||||
action: PayloadAction<
|
||||
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
|
||||
>
|
||||
) => {
|
||||
const { entityIdentifier, modelConfig } = action.payload;
|
||||
const entity = selectEntity(state, entityIdentifier);
|
||||
@@ -631,12 +638,39 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
if (entity.ipAdapter.model?.base === 'flux') {
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-L';
|
||||
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
|
||||
// Fall back to ViT-H (ViT-G would also work)
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-H';
|
||||
|
||||
if (!entity.ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
|
||||
// Switching from ip_adapter to flux_redux
|
||||
entity.ipAdapter = {
|
||||
...initialFLUXRedux,
|
||||
image: entity.ipAdapter.image,
|
||||
model: entity.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
|
||||
// Switching from flux_redux to ip_adapter
|
||||
entity.ipAdapter = {
|
||||
...initialIPAdapter,
|
||||
image: entity.ipAdapter.image,
|
||||
model: entity.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (entity.ipAdapter.type === 'ip_adapter') {
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
if (entity.ipAdapter.model?.base === 'flux') {
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-L';
|
||||
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
|
||||
// Fall back to ViT-H (ViT-G would also work)
|
||||
entity.ipAdapter.clipVisionModel = 'ViT-H';
|
||||
}
|
||||
}
|
||||
},
|
||||
referenceImageIPAdapterCLIPVisionModelChanged: (
|
||||
@@ -648,6 +682,9 @@ export const canvasSlice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
if (entity.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
entity.ipAdapter.clipVisionModel = clipVisionModel;
|
||||
},
|
||||
referenceImageIPAdapterWeightChanged: (
|
||||
@@ -659,6 +696,9 @@ export const canvasSlice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
if (entity.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
entity.ipAdapter.weight = weight;
|
||||
},
|
||||
referenceImageIPAdapterBeginEndStepPctChanged: (
|
||||
@@ -670,6 +710,9 @@ export const canvasSlice = createSlice({
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
if (entity.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
entity.ipAdapter.beginEndStepPct = beginEndStepPct;
|
||||
},
|
||||
//#region Regional Guidance
|
||||
@@ -843,6 +886,10 @@ export const canvasSlice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
referenceImage.ipAdapter.weight = weight;
|
||||
},
|
||||
rgIPAdapterBeginEndStepPctChanged: (
|
||||
@@ -856,6 +903,10 @@ export const canvasSlice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
referenceImage.ipAdapter.beginEndStepPct = beginEndStepPct;
|
||||
},
|
||||
rgIPAdapterMethodChanged: (
|
||||
@@ -869,6 +920,10 @@ export const canvasSlice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
referenceImage.ipAdapter.method = method;
|
||||
},
|
||||
rgIPAdapterModelChanged: (
|
||||
@@ -877,7 +932,7 @@ export const canvasSlice = createSlice({
|
||||
EntityIdentifierPayload<
|
||||
{
|
||||
referenceImageId: string;
|
||||
modelConfig: IPAdapterModelConfig | null;
|
||||
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null;
|
||||
},
|
||||
'regional_guidance'
|
||||
>
|
||||
@@ -889,12 +944,39 @@ export const canvasSlice = createSlice({
|
||||
return;
|
||||
}
|
||||
referenceImage.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
if (referenceImage.ipAdapter.model?.base === 'flux') {
|
||||
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
|
||||
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
|
||||
// Fall back to ViT-H (ViT-G would also work)
|
||||
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
|
||||
|
||||
if (!referenceImage.ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (referenceImage.ipAdapter.type === 'ip_adapter' && referenceImage.ipAdapter.model.type === 'flux_redux') {
|
||||
// Switching from ip_adapter to flux_redux
|
||||
referenceImage.ipAdapter = {
|
||||
...initialFLUXRedux,
|
||||
image: referenceImage.ipAdapter.image,
|
||||
model: referenceImage.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (referenceImage.ipAdapter.type === 'flux_redux' && referenceImage.ipAdapter.model.type === 'ip_adapter') {
|
||||
// Switching from flux_redux to ip_adapter
|
||||
referenceImage.ipAdapter = {
|
||||
...initialIPAdapter,
|
||||
image: referenceImage.ipAdapter.image,
|
||||
model: referenceImage.ipAdapter.model,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (referenceImage.ipAdapter.type === 'ip_adapter') {
|
||||
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
|
||||
if (referenceImage.ipAdapter.model?.base === 'flux') {
|
||||
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
|
||||
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
|
||||
// Fall back to ViT-H (ViT-G would also work)
|
||||
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
|
||||
}
|
||||
}
|
||||
},
|
||||
rgIPAdapterCLIPVisionModelChanged: (
|
||||
@@ -908,6 +990,10 @@ export const canvasSlice = createSlice({
|
||||
if (!referenceImage) {
|
||||
return;
|
||||
}
|
||||
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
|
||||
return;
|
||||
}
|
||||
|
||||
referenceImage.ipAdapter.clipVisionModel = clipVisionModel;
|
||||
},
|
||||
//#region Inpaint mask
|
||||
|
||||
@@ -233,6 +233,13 @@ const zIPAdapterConfig = z.object({
|
||||
});
|
||||
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
|
||||
|
||||
const zFLUXReduxConfig = z.object({
|
||||
type: z.literal('flux_redux'),
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
});
|
||||
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
|
||||
|
||||
const zCanvasEntityBase = z.object({
|
||||
id: zId,
|
||||
name: zName,
|
||||
@@ -242,10 +249,16 @@ const zCanvasEntityBase = z.object({
|
||||
|
||||
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
|
||||
type: z.literal('reference_image'),
|
||||
ipAdapter: zIPAdapterConfig,
|
||||
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
|
||||
});
|
||||
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
|
||||
|
||||
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
|
||||
config.type === 'ip_adapter';
|
||||
|
||||
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
|
||||
config.type === 'flux_redux';
|
||||
|
||||
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
|
||||
export type FillStyle = z.infer<typeof zFillStyle>;
|
||||
export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success;
|
||||
@@ -253,7 +266,7 @@ const zFill = z.object({ style: zFillStyle, color: zRgbColor });
|
||||
|
||||
const zRegionalGuidanceReferenceImageState = z.object({
|
||||
id: zId,
|
||||
ipAdapter: zIPAdapterConfig,
|
||||
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
|
||||
});
|
||||
export type RegionalGuidanceReferenceImageState = z.infer<typeof zRegionalGuidanceReferenceImageState>;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import type {
|
||||
CanvasRegionalGuidanceState,
|
||||
ControlLoRAConfig,
|
||||
ControlNetConfig,
|
||||
FLUXReduxConfig,
|
||||
ImageWithDims,
|
||||
IPAdapterConfig,
|
||||
RgbColor,
|
||||
@@ -70,6 +71,11 @@ export const initialIPAdapter: IPAdapterConfig = {
|
||||
clipVisionModel: 'ViT-H',
|
||||
weight: 1,
|
||||
};
|
||||
export const initialFLUXRedux: FLUXReduxConfig = {
|
||||
type: 'flux_redux',
|
||||
image: null,
|
||||
model: null,
|
||||
};
|
||||
export const initialT2IAdapter: T2IAdapterConfig = {
|
||||
type: 't2i_adapter',
|
||||
model: null,
|
||||
|
||||
@@ -44,33 +44,33 @@ export const getRegionalGuidanceWarnings = (
|
||||
if (model.base === 'sd-3' || model.base === 'sd-2') {
|
||||
// Unsupported model architecture
|
||||
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
|
||||
} else if (model.base === 'flux') {
|
||||
return warnings;
|
||||
}
|
||||
|
||||
if (model.base === 'flux') {
|
||||
// Some features are not supported for flux models
|
||||
if (entity.negativePrompt !== null) {
|
||||
warnings.push(WARNINGS.RG_NEGATIVE_PROMPT_NOT_SUPPORTED);
|
||||
}
|
||||
if (entity.referenceImages.length > 0) {
|
||||
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
|
||||
}
|
||||
if (entity.autoNegative) {
|
||||
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
|
||||
}
|
||||
} else {
|
||||
entity.referenceImages.forEach(({ ipAdapter }) => {
|
||||
if (!ipAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (ipAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
|
||||
if (!ipAdapter.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
entity.referenceImages.forEach(({ ipAdapter }) => {
|
||||
if (!ipAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (ipAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
|
||||
if (!ipAdapter.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return warnings;
|
||||
@@ -82,22 +82,27 @@ export const getGlobalReferenceImageWarnings = (
|
||||
): WarningTKey[] => {
|
||||
const warnings: WarningTKey[] = [];
|
||||
|
||||
if (!entity.ipAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (model) {
|
||||
if (model) {
|
||||
if (model.base === 'sd-3' || model.base === 'sd-2') {
|
||||
// Unsupported model architecture
|
||||
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
|
||||
} else if (entity.ipAdapter.model.base !== model.base) {
|
||||
return warnings;
|
||||
}
|
||||
|
||||
const { ipAdapter } = entity;
|
||||
|
||||
if (!ipAdapter.model) {
|
||||
// No model selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
|
||||
} else if (ipAdapter.model.base !== model.base) {
|
||||
// Supported model architecture but doesn't match
|
||||
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
|
||||
}
|
||||
}
|
||||
|
||||
if (!entity.ipAdapter.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
if (!entity.ipAdapter.image) {
|
||||
// No image selected
|
||||
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
|
||||
}
|
||||
}
|
||||
|
||||
return warnings;
|
||||
|
||||
@@ -6,7 +6,7 @@ import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { FluxReduxModelConfig } from 'services/api/types';
|
||||
import type { FLUXReduxModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -19,7 +19,7 @@ const FluxReduxModelFieldInputComponent = (
|
||||
const [modelConfigs, { isLoading }] = useFluxReduxModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: FluxReduxModelConfig | null) => {
|
||||
(value: FLUXReduxModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import type { CanvasReferenceImageState, FLUXReduxConfig } from 'features/controlLayers/store/types';
|
||||
import { isFLUXReduxConfig } from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
type AddFLUXReduxResult = {
|
||||
addedFLUXReduxes: number;
|
||||
};
|
||||
|
||||
type AddFLUXReduxArg = {
|
||||
entities: CanvasReferenceImageState[];
|
||||
g: Graph;
|
||||
collector: Invocation<'collect'>;
|
||||
model: ParameterModel;
|
||||
};
|
||||
|
||||
export const addFLUXReduxes = ({ entities, g, collector, model }: AddFLUXReduxArg): AddFLUXReduxResult => {
|
||||
const validFLUXReduxes = entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.filter((entity) => isFLUXReduxConfig(entity.ipAdapter))
|
||||
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
|
||||
|
||||
const result: AddFLUXReduxResult = {
|
||||
addedFLUXReduxes: 0,
|
||||
};
|
||||
|
||||
for (const { id, ipAdapter } of validFLUXReduxes) {
|
||||
assert(isFLUXReduxConfig(ipAdapter), 'This should have been filtered out');
|
||||
result.addedFLUXReduxes++;
|
||||
|
||||
addFLUXRedux(id, ipAdapter, g, collector);
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
const addFLUXRedux = (id: string, ipAdapter: FLUXReduxConfig, g: Graph, collector: Invocation<'collect'>) => {
|
||||
const { model: fluxReduxModel, image } = ipAdapter;
|
||||
assert(image, 'FLUX Redux image is required');
|
||||
assert(fluxReduxModel, 'FLUX Redux model is required');
|
||||
|
||||
const node = g.addNode({
|
||||
id: `flux_redux_${id}`,
|
||||
type: 'flux_redux',
|
||||
redux_model: fluxReduxModel,
|
||||
image: {
|
||||
image_name: image.image_name,
|
||||
},
|
||||
});
|
||||
|
||||
g.addEdge(node, 'redux_cond', collector, 'item');
|
||||
};
|
||||
@@ -1,4 +1,8 @@
|
||||
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
|
||||
import {
|
||||
type CanvasReferenceImageState,
|
||||
type IPAdapterConfig,
|
||||
isIPAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
@@ -19,23 +23,24 @@ type AddIPAdaptersArg = {
|
||||
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
|
||||
const validIPAdapters = entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.filter((entity) => isIPAdapterConfig(entity.ipAdapter))
|
||||
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
|
||||
|
||||
const result: AddIPAdaptersResult = {
|
||||
addedIPAdapters: 0,
|
||||
};
|
||||
|
||||
for (const ipa of validIPAdapters) {
|
||||
for (const { id, ipAdapter } of validIPAdapters) {
|
||||
assert(isIPAdapterConfig(ipAdapter), 'This should have been filtered out');
|
||||
result.addedIPAdapters++;
|
||||
|
||||
addIPAdapter(ipa, g, collector);
|
||||
addIPAdapter(id, ipAdapter, g, collector);
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: Invocation<'collect'>) => {
|
||||
const { id, ipAdapter } = entity;
|
||||
const addIPAdapter = (id: string, ipAdapter: IPAdapterConfig, g: Graph, collector: Invocation<'collect'>) => {
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(image, 'IP Adapter image is required');
|
||||
assert(model, 'IP Adapter model is required');
|
||||
|
||||
@@ -18,6 +18,7 @@ type AddedRegionResult = {
|
||||
addedNegativePrompt: boolean;
|
||||
addedAutoNegativePositivePrompt: boolean;
|
||||
addedIPAdapters: number;
|
||||
addedFLUXReduxes: number;
|
||||
};
|
||||
|
||||
type AddRegionsArg = {
|
||||
@@ -31,6 +32,7 @@ type AddRegionsArg = {
|
||||
posCondCollect: Invocation<'collect'>;
|
||||
negCondCollect: Invocation<'collect'> | null;
|
||||
ipAdapterCollect: Invocation<'collect'>;
|
||||
fluxReduxCollect: Invocation<'collect'> | null;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -45,6 +47,7 @@ type AddRegionsArg = {
|
||||
* @param posCondCollect The positive conditioning collector
|
||||
* @param negCondCollect The negative conditioning collector
|
||||
* @param ipAdapterCollect The IP adapter collector
|
||||
* @param fluxReduxConnect The IP adapter collector
|
||||
* @returns A promise that resolves to the regions that were successfully added to the graph
|
||||
*/
|
||||
|
||||
@@ -59,6 +62,7 @@ export const addRegions = async ({
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollect,
|
||||
fluxReduxCollect,
|
||||
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
|
||||
const isSDXL = model.base === 'sdxl';
|
||||
const isFLUX = model.base === 'flux';
|
||||
@@ -75,6 +79,7 @@ export const addRegions = async ({
|
||||
addedNegativePrompt: false,
|
||||
addedAutoNegativePositivePrompt: false,
|
||||
addedIPAdapters: 0,
|
||||
addedFLUXReduxes: 0,
|
||||
};
|
||||
|
||||
const getImageDTOResult = await withResultAsync(() => {
|
||||
@@ -269,30 +274,52 @@ export const addRegions = async ({
|
||||
}
|
||||
|
||||
for (const { id, ipAdapter } of region.referenceImages) {
|
||||
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
|
||||
if (ipAdapter.type === 'ip_adapter') {
|
||||
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
|
||||
|
||||
result.addedIPAdapters++;
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(model, 'IP Adapter model is required');
|
||||
assert(image, 'IP Adapter image is required');
|
||||
result.addedIPAdapters++;
|
||||
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
|
||||
assert(model, 'IP Adapter model is required');
|
||||
assert(image, 'IP Adapter image is required');
|
||||
|
||||
const ipAdapterNode = g.addNode({
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
weight,
|
||||
method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.image_name,
|
||||
},
|
||||
});
|
||||
const ipAdapterNode = g.addNode({
|
||||
id: `ip_adapter_${id}`,
|
||||
type: 'ip_adapter',
|
||||
weight,
|
||||
method,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginEndStepPct[0],
|
||||
end_step_percent: beginEndStepPct[1],
|
||||
image: {
|
||||
image_name: image.image_name,
|
||||
},
|
||||
});
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
|
||||
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
|
||||
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
|
||||
} else if (ipAdapter.type === 'flux_redux') {
|
||||
assert(isFLUX, 'Regional FLUX Redux requires FLUX.');
|
||||
assert(fluxReduxCollect !== null, 'FLUX Redux collector is required.');
|
||||
result.addedFLUXReduxes++;
|
||||
const { model: fluxReduxModel, image } = ipAdapter;
|
||||
assert(fluxReduxModel, 'FLUX Redux model is required');
|
||||
assert(image, 'FLUX Redux image is required');
|
||||
|
||||
const fluxReduxNode = g.addNode({
|
||||
id: `flux_redux_${id}`,
|
||||
type: 'flux_redux',
|
||||
redux_model: fluxReduxModel,
|
||||
image: {
|
||||
image_name: image.image_name,
|
||||
},
|
||||
});
|
||||
|
||||
// Connect the mask to the conditioning
|
||||
g.addEdge(maskToTensor, 'mask', fluxReduxNode, 'mask');
|
||||
g.addEdge(fluxReduxNode, 'redux_cond', fluxReduxCollect, 'item');
|
||||
}
|
||||
}
|
||||
|
||||
results.push(result);
|
||||
|
||||
@@ -7,6 +7,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
|
||||
import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux';
|
||||
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
|
||||
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
|
||||
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
|
||||
@@ -233,6 +234,17 @@ export const buildFLUXGraph = async (
|
||||
model: modelConfig,
|
||||
});
|
||||
|
||||
const fluxReduxCollect = g.addNode({
|
||||
type: 'collect',
|
||||
id: getPrefixedId('ip_adapter_collector'),
|
||||
});
|
||||
const fluxReduxResult = addFLUXReduxes({
|
||||
entities: canvas.referenceImages.entities,
|
||||
g,
|
||||
collector: fluxReduxCollect,
|
||||
model: modelConfig,
|
||||
});
|
||||
|
||||
const regionsResult = await addRegions({
|
||||
manager,
|
||||
regions: canvas.regionalGuidance.entities,
|
||||
@@ -244,6 +256,7 @@ export const buildFLUXGraph = async (
|
||||
posCondCollect,
|
||||
negCondCollect: null,
|
||||
ipAdapterCollect,
|
||||
fluxReduxCollect,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
@@ -254,6 +267,16 @@ export const buildFLUXGraph = async (
|
||||
g.deleteNode(ipAdapterCollect.id);
|
||||
}
|
||||
|
||||
const totalReduxesAdded =
|
||||
fluxReduxResult.addedFLUXReduxes + regionsResult.reduce((acc, r) => acc + r.addedFLUXReduxes, 0);
|
||||
if (totalReduxesAdded > 0) {
|
||||
g.addEdge(fluxReduxCollect, 'collection', denoise, 'redux_conditioning');
|
||||
} else {
|
||||
g.deleteNode(fluxReduxCollect.id);
|
||||
}
|
||||
|
||||
// TODO: Add FLUX Reduxes to denoise node like we do for ipa
|
||||
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
canvasOutput = addNSFWChecker(g, canvasOutput);
|
||||
}
|
||||
|
||||
@@ -281,6 +281,7 @@ export const buildSD1Graph = async (
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollect,
|
||||
fluxReduxCollect: null,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
|
||||
@@ -286,6 +286,7 @@ export const buildSDXLGraph = async (
|
||||
posCondCollect,
|
||||
negCondCollect,
|
||||
ipAdapterCollect,
|
||||
fluxReduxCollect: null,
|
||||
});
|
||||
|
||||
const totalIPAdaptersAdded =
|
||||
|
||||
@@ -19,7 +19,6 @@ import {
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isNonSDXLMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSD3MainModelModelConfig,
|
||||
@@ -39,7 +38,7 @@ const buildModelsHook =
|
||||
typeGuard: (config: AnyModelConfig, excludeSubmodels?: boolean) => config is T,
|
||||
excludeSubmodels?: boolean
|
||||
) =>
|
||||
() => {
|
||||
(filter: (config: T) => boolean = () => true) => {
|
||||
const result = useGetModelConfigsQuery(undefined);
|
||||
const modelConfigs = useMemo(() => {
|
||||
if (!result.data) {
|
||||
@@ -48,13 +47,13 @@ const buildModelsHook =
|
||||
|
||||
return modelConfigsAdapterSelectors
|
||||
.selectAll(result.data)
|
||||
.filter((config) => typeGuard(config, excludeSubmodels));
|
||||
}, [result]);
|
||||
.filter((config) => typeGuard(config, excludeSubmodels))
|
||||
.filter(filter);
|
||||
}, [filter, result.data]);
|
||||
|
||||
return [modelConfigs, result] as const;
|
||||
};
|
||||
|
||||
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||
@@ -78,6 +77,9 @@ export const useFluxVAEModels = (args?: ModelHookArgs) =>
|
||||
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
|
||||
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
|
||||
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
|
||||
export const useIPAdapterOrFLUXReduxModels = buildModelsHook(
|
||||
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
|
||||
);
|
||||
|
||||
// const buildModelsSelector =
|
||||
// <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>
|
||||
|
||||
@@ -63,7 +63,7 @@ type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
export type SigLipModelConfig = S['SigLIPConfig'];
|
||||
export type FluxReduxModelConfig = S['FluxReduxConfig'];
|
||||
export type FLUXReduxModelConfig = S['FluxReduxConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type AnyModelConfig =
|
||||
| ControlLoRAModelConfig
|
||||
@@ -80,7 +80,7 @@ export type AnyModelConfig =
|
||||
| MainModelConfig
|
||||
| CLIPVisionDiffusersConfig
|
||||
| SigLipModelConfig
|
||||
| FluxReduxModelConfig;
|
||||
| FLUXReduxModelConfig;
|
||||
|
||||
/**
|
||||
* Checks if a list of submodels contains any that match a given variant or type
|
||||
@@ -217,7 +217,7 @@ export const isSigLipModelConfig = (config: AnyModelConfig): config is SigLipMod
|
||||
return config.type === 'siglip';
|
||||
};
|
||||
|
||||
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FluxReduxModelConfig => {
|
||||
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FLUXReduxModelConfig => {
|
||||
return config.type === 'flux_redux';
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user