mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Add helpful error messages when FLUX Redux starter models are not installed.
This commit is contained in:
committed by
psychedelicious
parent
f1fde792ee
commit
82293ae3b2
@@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -34,6 +35,10 @@ class FluxReduxOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
|
||||
FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_redux",
|
||||
title="FLUX Redux",
|
||||
@@ -48,7 +53,6 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The FLUX Redux image prompt.")
|
||||
|
||||
# TODO(ryand): Add support for a mask.
|
||||
# TODO(ryand): Add helpful error messages that reference the starter models if a required model is not installed.
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
@@ -61,11 +65,19 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
|
||||
with context.models.load_by_attrs(
|
||||
name="SigLIP - google/siglip-so400m-patch14-384",
|
||||
base=BaseModelType.Any,
|
||||
type=ModelType.SigLIP,
|
||||
).model_on_device() as (_, siglip_pipeline):
|
||||
try:
|
||||
siglip_model = context.models.load_by_attrs(
|
||||
name=SIGLIP_STARTER_MODEL_NAME,
|
||||
base=BaseModelType.Any,
|
||||
type=ModelType.SigLIP,
|
||||
)
|
||||
except UnknownModelException as e:
|
||||
raise RuntimeError(
|
||||
f"The SigLIP model required for FLUX Redux is not installed. Install '{SIGLIP_STARTER_MODEL_NAME}' "
|
||||
"from the Starter Models tab."
|
||||
) from e
|
||||
|
||||
with siglip_model.model_on_device() as (_, siglip_pipeline):
|
||||
assert isinstance(siglip_pipeline, SigLipPipeline)
|
||||
return siglip_pipeline.encode_image(
|
||||
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
@@ -73,11 +85,19 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
|
||||
with context.models.load_by_attrs(
|
||||
name="FLUX Redux",
|
||||
base=BaseModelType.Flux,
|
||||
type=ModelType.FluxRedux,
|
||||
).model_on_device() as (_, flux_redux):
|
||||
try:
|
||||
redux_model = context.models.load_by_attrs(
|
||||
name=FLUX_REDUX_STARTER_MODEL_NAME,
|
||||
base=BaseModelType.Flux,
|
||||
type=ModelType.FluxRedux,
|
||||
)
|
||||
except UnknownModelException as e:
|
||||
raise RuntimeError(
|
||||
f"The FLUX Redux model is not installed. Install the '{FLUX_REDUX_STARTER_MODEL_NAME}' model from the "
|
||||
" Starter Models tab."
|
||||
) from e
|
||||
|
||||
with redux_model.model_on_device() as (_, flux_redux):
|
||||
assert isinstance(flux_redux, FluxReduxModel)
|
||||
dtype = next(flux_redux.parameters()).dtype
|
||||
encoded_x = encoded_x.to(dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user