mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add model selecton fields to the FluxReduxInvocation.
This commit is contained in:
committed by
psychedelicious
parent
7b48ef2264
commit
b6b21dbcbf
@@ -57,6 +57,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
|
||||
@@ -16,15 +16,12 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -59,6 +56,16 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
description="The bool mask associated with this FLUX Redux image prompt. Excluded regions should be set to "
|
||||
"False, included regions should be set to True.",
|
||||
)
|
||||
redux_model: ModelIdentifierField = InputField(
|
||||
description="The FLUX Redux model to use.",
|
||||
title="FLUX Redux Model",
|
||||
ui_type=UIType.FluxReduxModel,
|
||||
)
|
||||
siglip_model: ModelIdentifierField = InputField(
|
||||
description="The SigLIP model to use.",
|
||||
title="SigLIP Model",
|
||||
ui_type=UIType.SigLipModel,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
@@ -73,19 +80,7 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
|
||||
try:
|
||||
siglip_model = context.models.load_by_attrs(
|
||||
name=SIGLIP_STARTER_MODEL_NAME,
|
||||
base=BaseModelType.Any,
|
||||
type=ModelType.SigLIP,
|
||||
)
|
||||
except UnknownModelException as e:
|
||||
raise RuntimeError(
|
||||
f"The SigLIP model required for FLUX Redux is not installed. Install '{SIGLIP_STARTER_MODEL_NAME}' "
|
||||
"from the Starter Models tab."
|
||||
) from e
|
||||
|
||||
with siglip_model.model_on_device() as (_, siglip_pipeline):
|
||||
with context.models.load(self.siglip_model).model_on_device() as (_, siglip_pipeline):
|
||||
assert isinstance(siglip_pipeline, SigLipPipeline)
|
||||
return siglip_pipeline.encode_image(
|
||||
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
@@ -93,19 +88,7 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
|
||||
try:
|
||||
redux_model = context.models.load_by_attrs(
|
||||
name=FLUX_REDUX_STARTER_MODEL_NAME,
|
||||
base=BaseModelType.Flux,
|
||||
type=ModelType.FluxRedux,
|
||||
)
|
||||
except UnknownModelException as e:
|
||||
raise RuntimeError(
|
||||
f"The FLUX Redux model is not installed. Install the '{FLUX_REDUX_STARTER_MODEL_NAME}' model from the "
|
||||
" Starter Models tab."
|
||||
) from e
|
||||
|
||||
with redux_model.model_on_device() as (_, flux_redux):
|
||||
with context.models.load(self.redux_model).model_on_device() as (_, flux_redux):
|
||||
assert isinstance(flux_redux, FluxReduxModel)
|
||||
dtype = next(flux_redux.parameters()).dtype
|
||||
encoded_x = encoded_x.to(dtype=dtype)
|
||||
|
||||
@@ -44,6 +44,8 @@ import {
|
||||
isFloatGeneratorFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxReduxModelFieldInputInstance,
|
||||
isFluxReduxModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
@@ -74,6 +76,8 @@ import {
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSigLipModelFieldInputInstance,
|
||||
isSigLipModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
@@ -102,6 +106,7 @@ import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldI
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInputComponent';
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
@@ -111,6 +116,7 @@ import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComp
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SigLipModelFieldInputComponent from './inputs/SigLipModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
|
||||
@@ -339,6 +345,20 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <SpandrelImageToImageModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSigLipModelFieldInputTemplate(template)) {
|
||||
if (!isSigLipModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <SigLipModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isFluxReduxModelFieldInputTemplate(template)) {
|
||||
if (!isFluxReduxModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <FluxReduxModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isColorFieldInputTemplate(template)) {
|
||||
if (!isColorFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldFluxReduxModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useFluxReduxModels } from 'services/api/hooks/modelsByType';
|
||||
import type { FluxReduxModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const FluxReduxModelFieldInputComponent = (
|
||||
props: FieldComponentProps<FluxReduxModelFieldInputInstance, FluxReduxModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useFluxReduxModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: FluxReduxModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxReduxModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxReduxModelFieldInputComponent);
|
||||
@@ -0,0 +1,53 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldSigLipModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSigLipModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SigLipModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SigLipModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SigLipModelFieldInputInstance, SigLipModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSigLipModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: SigLipModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSigLipModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SigLipModelFieldInputComponent);
|
||||
Reference in New Issue
Block a user