First pass at frontend integration for FLUX Redux and SigLIP model types.

This commit is contained in:
Ryan Dick
2025-03-03 17:02:01 +00:00
committed by psychedelicious
parent 9c542ed655
commit 7b48ef2264
10 changed files with 148 additions and 1 deletions

View File

@@ -761,6 +761,7 @@
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"description": "Description",
"edit": "Edit",
"fluxRedux": "FLUX Redux",
"height": "Height",
"huggingFace": "HuggingFace",
"huggingFacePlaceholder": "owner/model-name",
@@ -835,6 +836,7 @@
"settings": "Settings",
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"sigLip": "SigLIP",
"spandrelImageToImage": "Image to Image (Spandrel)",
"starterBundles": "Starter Bundles",
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",

View File

@@ -14,10 +14,12 @@ import {
useControlLoRAModel,
useControlNetModels,
useEmbeddingModels,
useFluxReduxModels,
useIPAdapterModels,
useLoRAModels,
useMainModels,
useRefinerModels,
useSigLipModels,
useSpandrelImageToImageModels,
useT2IAdapterModels,
useT5EncoderModels,
@@ -112,6 +114,18 @@ const ModelList = () => {
[spandrelImageToImageModels, searchTerm, filteredModelType]
);
const [sigLipModels, { isLoading: isLoadingSigLipModels }] = useSigLipModels();
const filteredSigLipModels = useMemo(
() => modelsFilter(sigLipModels, searchTerm, filteredModelType),
[sigLipModels, searchTerm, filteredModelType]
);
const [fluxReduxModels, { isLoading: isLoadingFluxReduxModels }] = useFluxReduxModels();
const filteredFluxReduxModels = useMemo(
() => modelsFilter(fluxReduxModels, searchTerm, filteredModelType),
[fluxReduxModels, searchTerm, filteredModelType]
);
const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
@@ -124,6 +138,8 @@ const ModelList = () => {
filteredCLIPVisionModels.length +
filteredVAEModels.length +
filteredSpandrelImageToImageModels.length +
filteredSigLipModels.length +
filteredFluxReduxModels.length +
t5EncoderModels.length +
clipEmbedModels.length +
controlLoRAModels.length
@@ -139,6 +155,8 @@ const ModelList = () => {
filteredT2IAdapterModels.length,
filteredVAEModels.length,
filteredSpandrelImageToImageModels.length,
filteredSigLipModels.length,
filteredFluxReduxModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
controlLoRAModels.length,
@@ -229,6 +247,16 @@ const ModelList = () => {
key="spandrel-image-to-image"
/>
)}
{/* SigLIP List */}
{isLoadingSigLipModels && <FetchingModelsLoader loadingMessage="Loading SigLIP Models..." />}
{!isLoadingSigLipModels && filteredSigLipModels.length > 0 && (
<ModelListWrapper title={t('modelManager.sigLip')} modelList={filteredSigLipModels} key="sig-lip" />
)}
{/* Flux Redux List */}
{isLoadingFluxReduxModels && <FetchingModelsLoader loadingMessage="Loading Flux Redux Models..." />}
{!isLoadingFluxReduxModels && filteredFluxReduxModels.length > 0 && (
<ModelListWrapper title={t('modelManager.fluxRedux')} modelList={filteredFluxReduxModels} key="flux-redux" />
)}
{totalFilteredModels === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>{t('modelManager.noMatchingModels')}</Text>

View File

@@ -25,6 +25,8 @@ export const ModelTypeFilter = memo(() => {
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
control_lora: t('modelManager.controlLora'),
siglip: t('modelManager.siglip'),
flux_redux: t('modelManager.fluxRedux'),
}),
[t]
);

View File

@@ -18,6 +18,7 @@ import type {
FieldValue,
FloatFieldValue,
FloatGeneratorFieldValue,
FluxReduxModelFieldValue,
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
@@ -31,6 +32,7 @@ import type {
ModelIdentifierFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
SigLipModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
@@ -53,6 +55,7 @@ import {
zFloatFieldCollectionValue,
zFloatFieldValue,
zFloatGeneratorFieldValue,
zFluxReduxModelFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
@@ -66,6 +69,7 @@ import {
zModelIdentifierFieldValue,
zSchedulerFieldValue,
zSDXLRefinerModelFieldValue,
zSigLipModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
@@ -409,6 +413,12 @@ export const nodesSlice = createSlice({
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
fieldSigLipModelValueChanged: (state, action: FieldValueAction<SigLipModelFieldValue>) => {
fieldValueReducer(state, action, zSigLipModelFieldValue);
},
fieldFluxReduxModelValueChanged: (state, action: FieldValueAction<FluxReduxModelFieldValue>) => {
fieldValueReducer(state, action, zFluxReduxModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -516,6 +526,8 @@ export const {
fieldCLIPGEmbedValueChanged,
fieldControlLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,

View File

@@ -79,6 +79,8 @@ const zModelType = z.enum([
'spandrel_image_to_image',
't5_encoder',
'clip_embed',
'siglip',
'flux_redux',
]);
const zSubModelType = z.enum([
'unet',

View File

@@ -229,6 +229,14 @@ const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSigLipModelFieldType = zFieldTypeBase.extend({
name: z.literal('SigLipModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxReduxModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxReduxModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -275,6 +283,8 @@ const zStatefulFieldType = z.union([
zCLIPGEmbedModelFieldType,
zControlLoRAModelFieldType,
zFluxVAEModelFieldType,
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zColorFieldType,
zSchedulerFieldType,
zFloatGeneratorFieldType,
@@ -309,6 +319,8 @@ const modelFieldTypeNames = [
zCLIPGEmbedModelFieldType.shape.name.value,
zControlLoRAModelFieldType.shape.name.value,
zFluxVAEModelFieldType.shape.name.value,
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
'VAEField',
@@ -1074,6 +1086,42 @@ export const isControlLoRAModelFieldInputTemplate =
buildTemplateTypeGuard<ControlLoRAModelFieldInputTemplate>('ControlLoRAModelField');
// #endregion
// #region SigLipModelField
export const zSigLipModelFieldValue = zModelIdentifierField.optional();
const zSigLipModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSigLipModelFieldValue,
});
const zSigLipModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSigLipModelFieldType,
originalType: zFieldType.optional(),
default: zSigLipModelFieldValue,
});
export type SigLipModelFieldValue = z.infer<typeof zSigLipModelFieldValue>;
export type SigLipModelFieldInputInstance = z.infer<typeof zSigLipModelFieldInputInstance>;
export type SigLipModelFieldInputTemplate = z.infer<typeof zSigLipModelFieldInputTemplate>;
export const isSigLipModelFieldInputInstance = buildInstanceTypeGuard(zSigLipModelFieldInputInstance);
export const isSigLipModelFieldInputTemplate =
buildTemplateTypeGuard<SigLipModelFieldInputTemplate>('SigLipModelField');
// #endregion
// #region FluxReduxModelField
export const zFluxReduxModelFieldValue = zModelIdentifierField.optional();
const zFluxReduxModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxReduxModelFieldValue,
});
const zFluxReduxModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxReduxModelFieldType,
originalType: zFieldType.optional(),
default: zFluxReduxModelFieldValue,
});
export type FluxReduxModelFieldValue = z.infer<typeof zFluxReduxModelFieldValue>;
export type FluxReduxModelFieldInputInstance = z.infer<typeof zFluxReduxModelFieldInputInstance>;
export type FluxReduxModelFieldInputTemplate = z.infer<typeof zFluxReduxModelFieldInputTemplate>;
export const isFluxReduxModelFieldInputInstance = buildInstanceTypeGuard(zFluxReduxModelFieldInputInstance);
export const isFluxReduxModelFieldInputTemplate =
buildTemplateTypeGuard<FluxReduxModelFieldInputTemplate>('FluxReduxModelField');
// #endregion
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1701,6 +1749,8 @@ export const zStatefulFieldValue = z.union([
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zControlLoRAModelFieldValue,
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
zFloatGeneratorFieldValue,
@@ -1785,6 +1835,8 @@ const zStatefulFieldInputTemplate = z.union([
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zControlLoRAModelFieldInputTemplate,
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -29,6 +29,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
ControlLoRAModelField: undefined,
SigLipModelField: undefined,
FluxReduxModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,
StringGeneratorField: undefined,

View File

@@ -15,6 +15,7 @@ import type {
FloatFieldInputTemplate,
FloatGeneratorFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxReduxModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldInputTemplate,
@@ -30,6 +31,7 @@ import type {
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
SigLipModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
@@ -527,6 +529,33 @@ const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilde
return template;
};
const buildSigLipModelFieldInputTemplate: FieldInputTemplateBuilder<SigLipModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SigLipModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxReduxModelFieldInputTemplate: FieldInputTemplateBuilder<FluxReduxModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxReduxModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -729,6 +758,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
SigLipModelField: buildSigLipModelFieldInputTemplate,
FluxReduxModelField: buildFluxReduxModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
StringGeneratorField: buildStringGeneratorFieldInputTemplate,

View File

@@ -15,6 +15,7 @@ import {
isControlLoRAModelConfig,
isControlNetModelConfig,
isFluxMainModelModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
@@ -23,6 +24,7 @@ import {
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSigLipModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
@@ -74,6 +76,8 @@ export const useVAEModels = (args?: ModelHookArgs) => buildModelsHook(isVAEModel
export const useFluxVAEModels = (args?: ModelHookArgs) =>
buildModelsHook(isFluxVAEModelConfig, args?.excludeSubmodels)();
export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
export const useSigLipModels = buildModelsHook(isSigLipModelConfig);
export const useFluxReduxModels = buildModelsHook(isFluxReduxModelConfig);
// const buildModelsSelector =
// <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>

View File

@@ -62,6 +62,8 @@ type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualI
type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type SigLipModelConfig = S['SigLIPConfig'];
export type FluxReduxModelConfig = S['FluxReduxConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| ControlLoRAModelConfig
@@ -76,7 +78,9 @@ export type AnyModelConfig =
| SpandrelImageToImageModelConfig
| TextualInversionModelConfig
| MainModelConfig
| CLIPVisionDiffusersConfig;
| CLIPVisionDiffusersConfig
| SigLipModelConfig
| FluxReduxModelConfig;
/**
* Checks if a list of submodels contains any that match a given variant or type
@@ -209,6 +213,14 @@ export const isSpandrelImageToImageModelConfig = (
return config.type === 'spandrel_image_to_image';
};
export const isSigLipModelConfig = (config: AnyModelConfig): config is SigLipModelConfig => {
return config.type === 'siglip';
};
export const isFluxReduxModelConfig = (config: AnyModelConfig): config is FluxReduxModelConfig => {
return config.type === 'flux_redux';
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};