Add helpful error messages when FLUX Redux starter models are not installed.

This commit is contained in:
Ryan Dick
2025-02-28 18:33:51 +00:00
committed by psychedelicious
parent f1fde792ee
commit 82293ae3b2

View File

@@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import (
@@ -34,6 +35,10 @@ class FluxReduxOutput(BaseInvocationOutput):
)
SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"
@invocation(
"flux_redux",
title="FLUX Redux",
@@ -48,7 +53,6 @@ class FluxReduxInvocation(BaseInvocation):
image: ImageField = InputField(description="The FLUX Redux image prompt.")
# TODO(ryand): Add support for a mask.
# TODO(ryand): Add helpful error messages that reference the starter models if a required model is not installed.
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
@@ -61,11 +65,19 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
with context.models.load_by_attrs(
name="SigLIP - google/siglip-so400m-patch14-384",
base=BaseModelType.Any,
type=ModelType.SigLIP,
).model_on_device() as (_, siglip_pipeline):
try:
siglip_model = context.models.load_by_attrs(
name=SIGLIP_STARTER_MODEL_NAME,
base=BaseModelType.Any,
type=ModelType.SigLIP,
)
except UnknownModelException as e:
raise RuntimeError(
f"The SigLIP model required for FLUX Redux is not installed. Install '{SIGLIP_STARTER_MODEL_NAME}' "
"from the Starter Models tab."
) from e
with siglip_model.model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
@@ -73,11 +85,19 @@ class FluxReduxInvocation(BaseInvocation):
@torch.no_grad()
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
with context.models.load_by_attrs(
name="FLUX Redux",
base=BaseModelType.Flux,
type=ModelType.FluxRedux,
).model_on_device() as (_, flux_redux):
try:
redux_model = context.models.load_by_attrs(
name=FLUX_REDUX_STARTER_MODEL_NAME,
base=BaseModelType.Flux,
type=ModelType.FluxRedux,
)
except UnknownModelException as e:
raise RuntimeError(
f"The FLUX Redux model is not installed. Install the '{FLUX_REDUX_STARTER_MODEL_NAME}' model from the "
" Starter Models tab."
) from e
with redux_model.model_on_device() as (_, flux_redux):
assert isinstance(flux_redux, FluxReduxModel)
dtype = next(flux_redux.parameters()).dtype
encoded_x = encoded_x.to(dtype=dtype)