feat(ui): support for FLUX Redux in canvas

User facing:

When a FLUX main model is selected, users may now add Regional Reference Image layers.

When switching between FLUX Redux and FLUX IP Adapter, the settings will change to match the model type. (IP Adapter has weight, begin/end step, but Redux does not.) The image will be retained when switching between the two.

Otherwise it works the same way as IP Adapter - both in Global and Regional Reference Image layers.

---

Internal state handling:

Slightly awkward, but it was easiest to make FLUX Redux a second type of IP Adapter in redux state.

Global and regional reference images still have a single `ipAdapter` field, but it can have a type of `ip_adapter` or `flux_redux`.

Ideally, this field is called `config` or `settings` or something, but we are past that point. We _could_ do a migration to rename it, but I don't think it's worth the effort.

---

Other changes:
- Updated canvas layer validators to handle FLUX Redux.
- Updated model list loading logic to un-set FLUX Redux models in Canvas if they are not in the list (e.g. if the user deletes the model in the main app).
- Updated graph builders - new `addFLUXRedux` util & updated `addRegions` util.
- Updated the `buildModelsHook` util to return a hook that accepts a filter callback. This handles a discrepancy: FLUX IP Adapter does not support regional guidance, but FLUX Redux does. The Regional Guidance settings provide the filter to filter out FLUX IP Adapter models from the combined list of IP Adapter ahd Redux models.
This commit is contained in:
psychedelicious
2025-03-07 16:35:11 +10:00
parent f62b9ad919
commit c259899bf4
20 changed files with 494 additions and 189 deletions

View File

@@ -31,6 +31,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlLayerModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
@@ -77,6 +78,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleT5EncoderModels(models, state, dispatch, log);
handleCLIPEmbedModels(models, state, dispatch, log);
handleFLUXVAEModels(models, state, dispatch, log);
handleFLUXReduxModels(models, state, dispatch, log);
},
});
};
@@ -209,6 +211,10 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
const selectedIPAdapterModel = entity.ipAdapter.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
@@ -224,6 +230,10 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'ip_adapter') {
return;
}
const selectedIPAdapterModel = ipAdapter.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
@@ -241,6 +251,49 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
});
};
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'flux_redux') {
return;
}
const selectedFLUXReduxModel = entity.ipAdapter.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
}
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
if (isModelAvailable) {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'flux_redux') {
return;
}
const selectedFLUXReduxModel = ipAdapter.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
}
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
if (isModelAvailable) {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
);
});
});
};
const handlePostProcessingModel: ModelHandler = (models, state, dispatch, log) => {
const selectedPostProcessingModel = state.upscale.postProcessingModel;
const allSpandrelModels = models.filter(isSpandrelImageToImageModelConfig);

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -22,7 +22,6 @@ export const CanvasAddEntityButtons = memo(() => {
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
@@ -75,7 +74,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX || isSD3}
isDisabled={isSD3}
>
{t('controlLayers.regionalReferenceImage')}
</Button>

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -23,7 +23,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
@@ -52,7 +51,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isSD3}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>

View File

@@ -0,0 +1,61 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
const FLUX_CLIP_VISION = 'ViT-L';
const CLIP_VISION_OPTIONS = [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
];
type Props = {
model: CLIPVisionModelV2;
onChange: (clipVisionModel: CLIPVisionModelV2) => void;
};
export const CLIPVisionModel = memo(({ model, onChange }: Props) => {
const { t } = useTranslation();
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
(v) => {
assert(isCLIPVisionModelV2(v?.value));
onChange(v.value);
},
[onChange]
);
const isFLUX = useAppSelector(selectIsFLUX);
const clipVisionOptions = useMemo(() => {
return CLIP_VISION_OPTIONS.map((option) => ({
...option,
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
}));
}, [isFLUX]);
const clipVisionModelValue = useMemo(() => {
return CLIP_VISION_OPTIONS.find((o) => o.value === model);
}, [model]);
return (
<FormControl width="max-content" minWidth={28}>
<Combobox
options={clipVisionOptions}
placeholder={t('common.placeholderSelectAModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>
</FormControl>
);
});
CLIPVisionModel.displayName = 'CLIPVisionModel';

View File

@@ -1,40 +1,36 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
const FLUX_CLIP_VISION = 'ViT-L';
const CLIP_VISION_OPTIONS = [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
];
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
isRegionalGuidance: boolean;
modelKey: string | null;
onChangeModel: (modelConfig: IPAdapterModelConfig) => void;
clipVisionModel: CLIPVisionModelV2;
onChangeCLIPVisionModel: (clipVisionModel: CLIPVisionModelV2) => void;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
};
export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel, onChangeCLIPVisionModel }: Props) => {
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const filter = useCallback(
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
return false;
}
return true;
},
[isRegionalGuidance]
);
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | null) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null) => {
if (!modelConfig) {
return;
}
@@ -43,21 +39,11 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
[onChangeModel]
);
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
(v) => {
assert(isCLIPVisionModelV2(v?.value));
onChangeCLIPVisionModel(v.value);
},
[onChangeCLIPVisionModel]
);
const isFLUX = useAppSelector(selectIsFLUX);
const getIsDisabled = useCallback(
(model: AnyModelConfig): boolean => {
const isCompatible = currentBaseModel === model.base;
const hasMainModel = Boolean(currentBaseModel);
return !hasMainModel || !isCompatible;
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
},
[currentBaseModel]
);
@@ -70,41 +56,18 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
isLoading,
});
const clipVisionOptions = useMemo(() => {
return CLIP_VISION_OPTIONS.map((option) => ({
...option,
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
}));
}, [isFLUX]);
const clipVisionModelValue = useMemo(() => {
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
}, [clipVisionModel]);
return (
<Flex gap={2}>
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
{selectedModel?.format === 'checkpoint' && (
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
<Combobox
options={clipVisionOptions}
placeholder={t('common.placeholderSelectAModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>
</FormControl>
)}
</Flex>
<Tooltip label={selectedModel?.description}>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
);
});

View File

@@ -1,9 +1,10 @@
import { Box, Flex, IconButton } from '@invoke-ai/ui-library';
import { Flex, IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -25,7 +26,7 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
import { IPAdapterModel } from './IPAdapterModel';
@@ -65,7 +66,7 @@ const IPAdapterSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
},
[dispatch, entityIdentifier]
@@ -98,14 +99,14 @@ const IPAdapterSettingsContent = memo(() => {
<CanvasEntitySettingsWrapper>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
clipVisionModel={ipAdapter.clipVisionModel}
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
/>
</Box>
<IPAdapterModel
isRegionalGuidance={false}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
@@ -116,12 +117,14 @@ const IPAdapterSettingsContent = memo(() => {
/>
</Flex>
<Flex gap={2} w="full" alignItems="center">
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
{!isFLUX && <IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />}
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}

View File

@@ -1,8 +1,9 @@
import { Box, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
@@ -26,7 +27,7 @@ import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiXBold } from 'react-icons/pi';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type Props = {
@@ -73,7 +74,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
dispatch(rgIPAdapterModelChanged({ entityIdentifier, referenceImageId, modelConfig }));
},
[dispatch, entityIdentifier, referenceImageId]
@@ -125,14 +126,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
</Flex>
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex gap={2} alignItems="center" w="full">
<Box minW={0} w="full" transitionProperty="common" transitionDuration="0.1s">
<IPAdapterModel
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
clipVisionModel={ipAdapter.clipVisionModel}
onChangeCLIPVisionModel={onChangeCLIPVisionModel}
/>
</Box>
<IPAdapterModel
isRegionalGuidance={true}
modelKey={ipAdapter.model?.key ?? null}
onChangeModel={onChangeModel}
/>
{ipAdapter.type === 'ip_adapter' && (
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
)}
<IconButton
onClick={pullBboxIntoIPAdapter}
isDisabled={isBusy}
@@ -143,12 +144,14 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
/>
</Flex>
<Flex gap={2} w="full">
<Flex flexDir="column" gap={2} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1">
{ipAdapter.type === 'ip_adapter' && (
<Flex flexDir="column" gap={2} w="full">
<IPAdapterMethod method={ipAdapter.method} onChange={onChangeIPMethod} />
<Weight weight={ipAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={ipAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
</Flex>
)}
<Flex alignItems="center" justifyContent="center" h={32} w={32} aspectRatio="1/1" flexGrow={1}>
<IPAdapterImagePreview
image={ipAdapter.image}
onChangeImage={onChangeImage}

View File

@@ -38,6 +38,7 @@ import type { UndoableOptions } from 'redux-undo';
import type {
ControlLoRAModelConfig,
ControlNetModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
T2IAdapterModelConfig,
@@ -76,6 +77,7 @@ import {
imageDTOToImageWithDims,
initialControlLoRA,
initialControlNet,
initialFLUXRedux,
initialIPAdapter,
initialT2IAdapter,
} from './util';
@@ -619,11 +621,16 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.method = method;
},
referenceImageIPAdapterModelChanged: (
state,
action: PayloadAction<EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | null }, 'reference_image'>>
action: PayloadAction<
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
>
) => {
const { entityIdentifier, modelConfig } = action.payload;
const entity = selectEntity(state, entityIdentifier);
@@ -631,12 +638,39 @@ export const canvasSlice = createSlice({
return;
}
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
if (!entity.ipAdapter.model) {
return;
}
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
entity.ipAdapter = {
...initialFLUXRedux,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
entity.ipAdapter = {
...initialIPAdapter,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}
if (entity.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (entity.ipAdapter.model?.base === 'flux') {
entity.ipAdapter.clipVisionModel = 'ViT-L';
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
entity.ipAdapter.clipVisionModel = 'ViT-H';
}
}
},
referenceImageIPAdapterCLIPVisionModelChanged: (
@@ -648,6 +682,9 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.clipVisionModel = clipVisionModel;
},
referenceImageIPAdapterWeightChanged: (
@@ -659,6 +696,9 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.weight = weight;
},
referenceImageIPAdapterBeginEndStepPctChanged: (
@@ -670,6 +710,9 @@ export const canvasSlice = createSlice({
if (!entity) {
return;
}
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}
entity.ipAdapter.beginEndStepPct = beginEndStepPct;
},
//#region Regional Guidance
@@ -843,6 +886,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.weight = weight;
},
rgIPAdapterBeginEndStepPctChanged: (
@@ -856,6 +903,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.beginEndStepPct = beginEndStepPct;
},
rgIPAdapterMethodChanged: (
@@ -869,6 +920,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.method = method;
},
rgIPAdapterModelChanged: (
@@ -877,7 +932,7 @@ export const canvasSlice = createSlice({
EntityIdentifierPayload<
{
referenceImageId: string;
modelConfig: IPAdapterModelConfig | null;
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null;
},
'regional_guidance'
>
@@ -889,12 +944,39 @@ export const canvasSlice = createSlice({
return;
}
referenceImage.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (referenceImage.ipAdapter.model?.base === 'flux') {
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
if (!referenceImage.ipAdapter.model) {
return;
}
if (referenceImage.ipAdapter.type === 'ip_adapter' && referenceImage.ipAdapter.model.type === 'flux_redux') {
// Switching from ip_adapter to flux_redux
referenceImage.ipAdapter = {
...initialFLUXRedux,
image: referenceImage.ipAdapter.image,
model: referenceImage.ipAdapter.model,
};
return;
}
if (referenceImage.ipAdapter.type === 'flux_redux' && referenceImage.ipAdapter.model.type === 'ip_adapter') {
// Switching from flux_redux to ip_adapter
referenceImage.ipAdapter = {
...initialIPAdapter,
image: referenceImage.ipAdapter.image,
model: referenceImage.ipAdapter.model,
};
return;
}
if (referenceImage.ipAdapter.type === 'ip_adapter') {
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
if (referenceImage.ipAdapter.model?.base === 'flux') {
referenceImage.ipAdapter.clipVisionModel = 'ViT-L';
} else if (referenceImage.ipAdapter.clipVisionModel === 'ViT-L') {
// Fall back to ViT-H (ViT-G would also work)
referenceImage.ipAdapter.clipVisionModel = 'ViT-H';
}
}
},
rgIPAdapterCLIPVisionModelChanged: (
@@ -908,6 +990,10 @@ export const canvasSlice = createSlice({
if (!referenceImage) {
return;
}
if (referenceImage.ipAdapter.type !== 'ip_adapter') {
return;
}
referenceImage.ipAdapter.clipVisionModel = clipVisionModel;
},
//#region Inpaint mask

View File

@@ -233,6 +233,13 @@ const zIPAdapterConfig = z.object({
});
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
const zFLUXReduxConfig = z.object({
type: z.literal('flux_redux'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type FLUXReduxConfig = z.infer<typeof zFLUXReduxConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -242,10 +249,16 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
ipAdapter: zIPAdapterConfig,
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;
export const isIPAdapterConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is IPAdapterConfig =>
config.type === 'ip_adapter';
export const isFLUXReduxConfig = (config: IPAdapterConfig | FLUXReduxConfig): config is FLUXReduxConfig =>
config.type === 'flux_redux';
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success;
@@ -253,7 +266,7 @@ const zFill = z.object({ style: zFillStyle, color: zRgbColor });
const zRegionalGuidanceReferenceImageState = z.object({
id: zId,
ipAdapter: zIPAdapterConfig,
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig]),
});
export type RegionalGuidanceReferenceImageState = z.infer<typeof zRegionalGuidanceReferenceImageState>;

View File

@@ -9,6 +9,7 @@ import type {
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
FLUXReduxConfig,
ImageWithDims,
IPAdapterConfig,
RgbColor,
@@ -70,6 +71,11 @@ export const initialIPAdapter: IPAdapterConfig = {
clipVisionModel: 'ViT-H',
weight: 1,
};
export const initialFLUXRedux: FLUXReduxConfig = {
type: 'flux_redux',
image: null,
model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -44,33 +44,33 @@ export const getRegionalGuidanceWarnings = (
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (model.base === 'flux') {
return warnings;
}
if (model.base === 'flux') {
// Some features are not supported for flux models
if (entity.negativePrompt !== null) {
warnings.push(WARNINGS.RG_NEGATIVE_PROMPT_NOT_SUPPORTED);
}
if (entity.referenceImages.length > 0) {
warnings.push(WARNINGS.RG_REFERENCE_IMAGES_NOT_SUPPORTED);
}
if (entity.autoNegative) {
warnings.push(WARNINGS.RG_AUTO_NEGATIVE_NOT_SUPPORTED);
}
} else {
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
if (!ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
});
}
entity.referenceImages.forEach(({ ipAdapter }) => {
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
if (!ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
});
}
return warnings;
@@ -82,22 +82,27 @@ export const getGlobalReferenceImageWarnings = (
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
if (!entity.ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (model) {
if (model) {
if (model.base === 'sd-3' || model.base === 'sd-2') {
// Unsupported model architecture
warnings.push(WARNINGS.UNSUPPORTED_MODEL);
} else if (entity.ipAdapter.model.base !== model.base) {
return warnings;
}
const { ipAdapter } = entity;
if (!ipAdapter.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
} else if (ipAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.IP_ADAPTER_INCOMPATIBLE_BASE_MODEL);
}
}
if (!entity.ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
if (!entity.ipAdapter.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
}
return warnings;

View File

@@ -6,7 +6,7 @@ import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
import type { FluxReduxModelConfig } from 'services/api/types';
import type { FLUXReduxModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@@ -19,7 +19,7 @@ const FluxReduxModelFieldInputComponent = (
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const _onChange = useCallback(
(value: FluxReduxModelConfig | null) => {
(value: FLUXReduxModelConfig | null) => {
if (!value) {
return;
}

View File

@@ -0,0 +1,55 @@
import type { CanvasReferenceImageState, FLUXReduxConfig } from 'features/controlLayers/store/types';
import { isFLUXReduxConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
type AddFLUXReduxResult = {
addedFLUXReduxes: number;
};
type AddFLUXReduxArg = {
entities: CanvasReferenceImageState[];
g: Graph;
collector: Invocation<'collect'>;
model: ParameterModel;
};
export const addFLUXReduxes = ({ entities, g, collector, model }: AddFLUXReduxArg): AddFLUXReduxResult => {
const validFLUXReduxes = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isFLUXReduxConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddFLUXReduxResult = {
addedFLUXReduxes: 0,
};
for (const { id, ipAdapter } of validFLUXReduxes) {
assert(isFLUXReduxConfig(ipAdapter), 'This should have been filtered out');
result.addedFLUXReduxes++;
addFLUXRedux(id, ipAdapter, g, collector);
}
return result;
};
const addFLUXRedux = (id: string, ipAdapter: FLUXReduxConfig, g: Graph, collector: Invocation<'collect'>) => {
const { model: fluxReduxModel, image } = ipAdapter;
assert(image, 'FLUX Redux image is required');
assert(fluxReduxModel, 'FLUX Redux model is required');
const node = g.addNode({
id: `flux_redux_${id}`,
type: 'flux_redux',
redux_model: fluxReduxModel,
image: {
image_name: image.image_name,
},
});
g.addEdge(node, 'redux_cond', collector, 'item');
};

View File

@@ -1,4 +1,8 @@
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
import {
type CanvasReferenceImageState,
type IPAdapterConfig,
isIPAdapterConfig,
} from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
@@ -19,23 +23,24 @@ type AddIPAdaptersArg = {
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
const validIPAdapters = entities
.filter((entity) => entity.isEnabled)
.filter((entity) => isIPAdapterConfig(entity.ipAdapter))
.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddIPAdaptersResult = {
addedIPAdapters: 0,
};
for (const ipa of validIPAdapters) {
for (const { id, ipAdapter } of validIPAdapters) {
assert(isIPAdapterConfig(ipAdapter), 'This should have been filtered out');
result.addedIPAdapters++;
addIPAdapter(ipa, g, collector);
addIPAdapter(id, ipAdapter, g, collector);
}
return result;
};
const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: Invocation<'collect'>) => {
const { id, ipAdapter } = entity;
const addIPAdapter = (id: string, ipAdapter: IPAdapterConfig, g: Graph, collector: Invocation<'collect'>) => {
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');

View File

@@ -18,6 +18,7 @@ type AddedRegionResult = {
addedNegativePrompt: boolean;
addedAutoNegativePositivePrompt: boolean;
addedIPAdapters: number;
addedFLUXReduxes: number;
};
type AddRegionsArg = {
@@ -31,6 +32,7 @@ type AddRegionsArg = {
posCondCollect: Invocation<'collect'>;
negCondCollect: Invocation<'collect'> | null;
ipAdapterCollect: Invocation<'collect'>;
fluxReduxCollect: Invocation<'collect'> | null;
};
/**
@@ -45,6 +47,7 @@ type AddRegionsArg = {
* @param posCondCollect The positive conditioning collector
* @param negCondCollect The negative conditioning collector
* @param ipAdapterCollect The IP adapter collector
* @param fluxReduxConnect The IP adapter collector
* @returns A promise that resolves to the regions that were successfully added to the graph
*/
@@ -59,6 +62,7 @@ export const addRegions = async ({
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';
@@ -75,6 +79,7 @@ export const addRegions = async ({
addedNegativePrompt: false,
addedAutoNegativePositivePrompt: false,
addedIPAdapters: 0,
addedFLUXReduxes: 0,
};
const getImageDTOResult = await withResultAsync(() => {
@@ -269,30 +274,52 @@ export const addRegions = async ({
}
for (const { id, ipAdapter } of region.referenceImages) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
if (ipAdapter.type === 'ip_adapter') {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
assert(image, 'IP Adapter image is required');
const ipAdapterNode = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.image_name,
},
});
const ipAdapterNode = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: image.image_name,
},
});
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', ipAdapterNode, 'mask');
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
} else if (ipAdapter.type === 'flux_redux') {
assert(isFLUX, 'Regional FLUX Redux requires FLUX.');
assert(fluxReduxCollect !== null, 'FLUX Redux collector is required.');
result.addedFLUXReduxes++;
const { model: fluxReduxModel, image } = ipAdapter;
assert(fluxReduxModel, 'FLUX Redux model is required');
assert(image, 'FLUX Redux image is required');
const fluxReduxNode = g.addNode({
id: `flux_redux_${id}`,
type: 'flux_redux',
redux_model: fluxReduxModel,
image: {
image_name: image.image_name,
},
});
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', fluxReduxNode, 'mask');
g.addEdge(fluxReduxNode, 'redux_cond', fluxReduxCollect, 'item');
}
}
results.push(result);

View File

@@ -7,6 +7,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs';
import { addFLUXReduxes } from 'features/nodes/util/graph/generation/addFLUXRedux';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
@@ -233,6 +234,17 @@ export const buildFLUXGraph = async (
model: modelConfig,
});
const fluxReduxCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
const fluxReduxResult = addFLUXReduxes({
entities: canvas.referenceImages.entities,
g,
collector: fluxReduxCollect,
model: modelConfig,
});
const regionsResult = await addRegions({
manager,
regions: canvas.regionalGuidance.entities,
@@ -244,6 +256,7 @@ export const buildFLUXGraph = async (
posCondCollect,
negCondCollect: null,
ipAdapterCollect,
fluxReduxCollect,
});
const totalIPAdaptersAdded =
@@ -254,6 +267,16 @@ export const buildFLUXGraph = async (
g.deleteNode(ipAdapterCollect.id);
}
const totalReduxesAdded =
fluxReduxResult.addedFLUXReduxes + regionsResult.reduce((acc, r) => acc + r.addedFLUXReduxes, 0);
if (totalReduxesAdded > 0) {
g.addEdge(fluxReduxCollect, 'collection', denoise, 'redux_conditioning');
} else {
g.deleteNode(fluxReduxCollect.id);
}
// TODO: Add FLUX Reduxes to denoise node like we do for ipa
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}

View File

@@ -281,6 +281,7 @@ export const buildSD1Graph = async (
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect: null,
});
const totalIPAdaptersAdded =

View File

@@ -286,6 +286,7 @@ export const buildSDXLGraph = async (
posCondCollect,
negCondCollect,
ipAdapterCollect,
fluxReduxCollect: null,
});
const totalIPAdaptersAdded =

View File

@@ -19,7 +19,6 @@ import {
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
@@ -39,7 +38,7 @@ const buildModelsHook =
typeGuard: (config: AnyModelConfig, excludeSubmodels?: boolean) => config is T,
excludeSubmodels?: boolean
) =>
() => {
(filter: (config: T) => boolean = () => true) => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
@@ -48,13 +47,13 @@ const buildModelsHook =
return modelConfigsAdapterSelectors
.selectAll(result.data)
.filter((config) => typeGuard(config, excludeSubmodels));
}, [result]);
.filter((config) => typeGuard(config, excludeSubmodels))
.filter(filter);
}, [filter, result.data]);
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
@@ -78,6 +77,9 @@ export const useFluxVAEModels = (args?: ModelHookArgs) =>
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
export const useIPAdapterOrFLUXReduxModels = buildModelsHook(
(config) => isIPAdapterModelConfig(config) || isFluxReduxModelConfig(config)
);
// const buildModelsSelector =
// <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>

View File

@@ -63,7 +63,7 @@ type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type SigLipModelConfig = S['SigLIPConfig'];
export type FluxReduxModelConfig = S['FluxReduxConfig'];
export type FLUXReduxModelConfig = S['FluxReduxConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| ControlLoRAModelConfig
@@ -80,7 +80,7 @@ export type AnyModelConfig =
| MainModelConfig
| CLIPVisionDiffusersConfig
| SigLipModelConfig
| FluxReduxModelConfig;
| FLUXReduxModelConfig;
/**
* Checks if a list of submodels contains any that match a given variant or type
@@ -217,7 +217,7 @@ export const isSigLipModelConfig = (config: AnyModelConfig): config is SigLipMod
return config.type === 'siglip';
};
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FluxReduxModelConfig => {
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FLUXReduxModelConfig => {
return config.type === 'flux_redux';
};