Add model selecton fields to the FluxReduxInvocation.

This commit is contained in:
Ryan Dick
2025-03-03 17:30:52 +00:00
committed by psychedelicious
parent 7b48ef2264
commit b6b21dbcbf
5 changed files with 142 additions and 31 deletions

View File

@@ -57,6 +57,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
# endregion
# region Misc Field Types

View File

@@ -16,15 +16,12 @@ from invokeai.app.invocations.fields import (
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelType,
)
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice
@@ -59,6 +56,16 @@ class FluxReduxInvocation(BaseInvocation):
description="The bool mask associated with this FLUX Redux image prompt. Excluded regions should be set to "
"False, included regions should be set to True.",
)
redux_model: ModelIdentifierField = InputField(
description="The FLUX Redux model to use.",
title="FLUX Redux Model",
ui_type=UIType.FluxReduxModel,
)
siglip_model: ModelIdentifierField = InputField(
description="The SigLIP model to use.",
title="SigLIP Model",
ui_type=UIType.SigLipModel,
)
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
@@ -73,19 +80,7 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
try:
siglip_model = context.models.load_by_attrs(
name=SIGLIP_STARTER_MODEL_NAME,
base=BaseModelType.Any,
type=ModelType.SigLIP,
)
except UnknownModelException as e:
raise RuntimeError(
f"The SigLIP model required for FLUX Redux is not installed. Install '{SIGLIP_STARTER_MODEL_NAME}' "
"from the Starter Models tab."
) from e
with siglip_model.model_on_device() as (_, siglip_pipeline):
with context.models.load(self.siglip_model).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
@@ -93,19 +88,7 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
try:
redux_model = context.models.load_by_attrs(
name=FLUX_REDUX_STARTER_MODEL_NAME,
base=BaseModelType.Flux,
type=ModelType.FluxRedux,
)
except UnknownModelException as e:
raise RuntimeError(
f"The FLUX Redux model is not installed. Install the '{FLUX_REDUX_STARTER_MODEL_NAME}' model from the "
" Starter Models tab."
) from e
with redux_model.model_on_device() as (_, flux_redux):
with context.models.load(self.redux_model).model_on_device() as (_, flux_redux):
assert isinstance(flux_redux, FluxReduxModel)
dtype = next(flux_redux.parameters()).dtype
encoded_x = encoded_x.to(dtype=dtype)

View File

@@ -44,6 +44,8 @@ import {
isFloatGeneratorFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxReduxModelFieldInputInstance,
isFluxReduxModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
isFluxVAEModelFieldInputTemplate,
isImageFieldCollectionInputInstance,
@@ -74,6 +76,8 @@ import {
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate,
isSigLipModelFieldInputInstance,
isSigLipModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
@@ -102,6 +106,7 @@ import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldI
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInputComponent';
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
@@ -111,6 +116,7 @@ import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComp
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SigLipModelFieldInputComponent from './inputs/SigLipModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
@@ -339,6 +345,20 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSigLipModelFieldInputTemplate(template)) {
if (!isSigLipModelFieldInputInstance(field)) {
return null;
}
return <SigLipModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxReduxModelFieldInputTemplate(template)) {
if (!isFluxReduxModelFieldInputInstance(field)) {
return null;
}
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isColorFieldInputTemplate(template)) {
if (!isColorFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,53 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
import type { FluxReduxModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const FluxReduxModelFieldInputComponent = (
props: FieldComponentProps<FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxReduxModels();
const _onChange = useCallback(
(value: FluxReduxModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxReduxModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
);
};
export default memo(FluxReduxModelFieldInputComponent);

View File

@@ -0,0 +1,53 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSigLipModels } from 'services/api/hooks/modelsByType';
import type { SigLipModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SigLipModelFieldInputComponent = (
props: FieldComponentProps<SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSigLipModels();
const _onChange = useCallback(
(value: SigLipModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSigLipModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
);
};
export default memo(SigLipModelFieldInputComponent);