feat(nodes): remove siglip from flux_redux, dl it jit when needed if we cannot find it

This follows the same pattern for IP Adapter w/ its CLIP Vision model. The SigLIP model is unlikely to ever change and we don't want to force the user to select it anywhere. Hardcoding it is safe and makes the UX much nicer.

The alternative is a model dropdown that will likely only ever have one valid choice in it.
This commit is contained in:
psychedelicious
2025-03-07 16:01:48 +10:00
parent e35537e60a
commit 57533657f9

View File

@@ -20,8 +20,11 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.backend.model_manager.starter_models import siglip
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice
@@ -35,16 +38,12 @@ class FluxReduxOutput(BaseInvocationOutput):
)
SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"
@invocation(
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.0.0",
version="2.0.0",
classification=Classification.Prototype,
)
class FluxReduxInvocation(BaseInvocation):
@@ -61,11 +60,6 @@ class FluxReduxInvocation(BaseInvocation):
title="FLUX Redux Model",
ui_type=UIType.FluxReduxModel,
)
siglip_model: ModelIdentifierField = InputField(
description="The SigLIP model to use.",
title="SigLIP Model",
ui_type=UIType.SigLipModel,
)
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
@@ -80,7 +74,8 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
with context.models.load(self.siglip_model).model_on_device() as (_, siglip_pipeline):
siglip_model_config = self._get_siglip_model(context)
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
@@ -93,3 +88,32 @@ class FluxReduxInvocation(BaseInvocation):
dtype = next(flux_redux.parameters()).dtype
encoded_x = encoded_x.to(dtype=dtype)
return flux_redux(encoded_x)
def _get_siglip_model(self, context: InvocationContext) -> AnyModelConfig:
siglip_models = context.models.search_by_attrs(name=siglip.name, base=BaseModelType.Any, type=ModelType.SigLIP)
if not len(siglip_models) > 0:
context.logger.warning(
f"The SigLIP model required by FLUX Redux ({siglip.name}) is not installed. Downloading and installing now. This may take a while."
)
# TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
config_overrides = ModelRecordChanges(name=siglip.name, type=ModelType.SigLIP)
# Queue the job
job = context._services.model_manager.install.heuristic_import(siglip.source, config=config_overrides)
# Wait for up to 10 minutes - model is ~3.5GB
context._services.model_manager.install.wait_for_job(job, timeout=600)
siglip_models = context.models.search_by_attrs(
name=siglip.name,
base=BaseModelType.Any,
type=ModelType.SigLIP,
)
if len(siglip_models) == 0:
context.logger.error("Error while fetching SigLIP for FLUX Redux")
assert len(siglip_models) == 1
return siglip_models[0]