Get FLUX Redux working: model loading and inference.

This commit is contained in:
Ryan Dick
2025-02-27 23:02:11 +00:00
committed by psychedelicious
parent e82393f7ed
commit f1fde792ee
10 changed files with 300 additions and 9 deletions

View File

@@ -202,6 +202,7 @@ class FieldDescriptions:
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
class ImageField(BaseModel):
@@ -260,6 +261,12 @@ class FluxConditioningField(BaseModel):
)
class FluxReduxConditioningField(BaseModel):
"""A FLUX Redux conditioning tensor primitive value"""
tensor_name: str = Field(description="The name of the conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
@@ -46,7 +47,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
@@ -103,6 +104,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
redux_conditioning: FluxReduxConditioningField | list[FluxReduxConditioningField] | None = InputField(
default=None,
description="FLUX Redux conditioning tensor.",
input=Input.Connection,
)
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
cfg_scale_start_step: int = InputField(
default=0,
@@ -190,11 +196,21 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
redux_conditionings: list[FluxReduxConditioning] = self._load_redux_conditioning(
context=context,
redux_cond_field=self.redux_conditioning,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
pos_text_conditionings, img_seq_len=packed_h * packed_w
text_conditioning=pos_text_conditionings,
redux_conditioning=redux_conditionings,
img_seq_len=packed_h * packed_w,
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
RegionalPromptingExtension.from_text_conditioning(
text_conditioning=neg_text_conditionings, redux_conditioning=[], img_seq_len=packed_h * packed_w
)
if neg_text_conditionings
else None
)
@@ -400,6 +416,29 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return text_conditionings
def _load_redux_conditioning(
self,
context: InvocationContext,
redux_cond_field: FluxReduxConditioningField | list[FluxReduxConditioningField] | None,
device: torch.device,
dtype: torch.dtype,
) -> list[FluxReduxConditioning]:
# Normalize to a list of FluxReduxConditioningFields.
if redux_cond_field is None:
return []
redux_cond_list = (
[redux_cond_field] if isinstance(redux_cond_field, FluxReduxConditioningField) else redux_cond_field
)
redux_conditionings: list[FluxReduxConditioning] = []
for redux_cond_field in redux_cond_list:
redux_cond_data = context.tensors.load(redux_cond_field.tensor_name)
redux_cond_data.to(device=device, dtype=dtype)
redux_conditionings.append(FluxReduxConditioning(redux_embeddings=redux_cond_data))
return redux_conditionings
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int

View File

@@ -0,0 +1,84 @@
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxReduxConditioningField,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelType,
)
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice
@invocation_output("flux_redux_output")
class FluxReduxOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Redux invocation."""
redux_cond: FluxReduxConditioningField = OutputField(
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
)
@invocation(
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxReduxInvocation(BaseInvocation):
"""Runs a FLUX Redux model to generate a conditioning tensor."""
image: ImageField = InputField(description="The FLUX Redux image prompt.")
# TODO(ryand): Add support for a mask.
# TODO(ryand): Add helpful error messages that reference the starter models if a required model is not installed.
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
encoded_x = self._siglip_encode(context, image)
redux_conditioning = self._flux_redux_encode(context, encoded_x)
tensor_name = context.tensors.save(redux_conditioning)
return FluxReduxOutput(redux_cond=FluxReduxConditioningField(tensor_name=tensor_name))
@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
with context.models.load_by_attrs(
name="SigLIP - google/siglip-so400m-patch14-384",
base=BaseModelType.Any,
type=ModelType.SigLIP,
).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
@torch.no_grad()
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
with context.models.load_by_attrs(
name="FLUX Redux",
base=BaseModelType.Flux,
type=ModelType.FluxRedux,
).model_on_device() as (_, flux_redux):
assert isinstance(flux_redux, FluxReduxModel)
dtype = next(flux_redux.parameters()).dtype
encoded_x = encoded_x.to(dtype=dtype)
return flux_redux(encoded_x)

View File

@@ -3,7 +3,11 @@ from typing import Optional
import torch
import torchvision
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.flux.text_conditioning import (
FluxReduxConditioning,
FluxRegionalTextConditioning,
FluxTextConditioning,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask
@@ -32,14 +36,19 @@ class RegionalPromptingExtension:
return order[block_index % len(order)]
@classmethod
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
def from_text_conditioning(
cls,
text_conditioning: list[FluxTextConditioning],
redux_conditioning: list[FluxReduxConditioning],
img_seq_len: int,
):
"""Create a RegionalPromptingExtension from a list of text conditionings.
Args:
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
"""
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning, redux_conditioning)
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
regional_text_conditioning, img_seq_len
)
@@ -202,6 +211,7 @@ class RegionalPromptingExtension:
def _concat_regional_text_conditioning(
cls,
text_conditionings: list[FluxTextConditioning],
redux_conditionings: list[FluxReduxConditioning],
) -> FluxRegionalTextConditioning:
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
concat_t5_embeddings: list[torch.Tensor] = []
@@ -217,18 +227,27 @@ class RegionalPromptingExtension:
global_clip_embedding = text_conditioning.clip_embeddings
break
# Handle T5 text embeddings.
cur_t5_embedding_len = 0
for text_conditioning in text_conditionings:
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
concat_t5_embedding_ranges.append(
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
)
image_masks.append(text_conditioning.mask)
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
# Handle Redux embeddings.
for redux_conditioning in redux_conditionings:
concat_t5_embeddings.append(redux_conditioning.redux_embeddings)
concat_t5_embedding_ranges.append(
Range(
start=cur_t5_embedding_len, end=cur_t5_embedding_len + redux_conditioning.redux_embeddings.shape[1]
)
)
image_masks.append(None)
cur_t5_embedding_len += redux_conditioning.redux_embeddings.shape[1]
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
# Initialize the txt_ids tensor.

View File

@@ -0,0 +1,17 @@
import torch
# This model definition is based on:
# https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/modules/image_embedders.py#L66
class FluxReduxModel(torch.nn.Module):
def __init__(self, redux_dim: int = 1152, txt_in_features: int = 4096) -> None:
super().__init__()
self.redux_dim = redux_dim
self.redux_up = torch.nn.Linear(redux_dim, txt_in_features * 3)
self.redux_down = torch.nn.Linear(txt_in_features * 3, txt_in_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.redux_down(torch.nn.functional.silu(self.redux_up(x)))

View File

@@ -13,6 +13,11 @@ class FluxTextConditioning:
mask: torch.Tensor | None
@dataclass
class FluxReduxConditioning:
redux_embeddings: torch.Tensor
@dataclass
class FluxRegionalTextConditioning:
# Concatenated text embeddings.

View File

@@ -25,6 +25,7 @@ from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
)
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
@@ -39,6 +40,7 @@ from invokeai.backend.model_manager.config import (
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
FluxReduxConfig,
IPAdapterCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
@@ -391,3 +393,25 @@ class FluxIpAdapterModel(ModelLoader):
model.load_xlabs_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.FluxRedux, format=ModelFormat.Checkpoint)
class FluxReduxModelLoader(ModelLoader):
"""Class to load FLUX Redux models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, FluxReduxConfig):
raise ValueError(f"Unexpected model config type: {type(config)}.")
sd = load_file(Path(config.path))
with accelerate.init_empty_weights():
model = FluxReduxModel()
model.load_state_dict(sd, assign=True)
model.to(dtype=torch.bfloat16)
return model

View File

@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.SigLIP, format=ModelFormat.Diffusers)
class SigLIPModelLoader(ModelLoader):
"""Class for loading SigLIP models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
model_path = Path(config.path)
model = SigLipPipeline.load_from_path(model_path)
model.to(dtype=self._torch_dtype)
return model

View File

@@ -0,0 +1,36 @@
from pathlib import Path
from typing import Optional
import torch
from PIL import Image
from transformers import SiglipImageProcessor, SiglipVisionModel
from invokeai.backend.raw_model import RawModel
class SigLipPipeline(RawModel):
"""A wrapper for a SigLIP model + processor."""
def __init__(
self,
siglip_processor: SiglipImageProcessor,
siglip_model: SiglipVisionModel,
):
self._siglip_processor = siglip_processor
self._siglip_model = siglip_model
@classmethod
def load_from_path(cls, path: str | Path):
siglip_model = SiglipVisionModel.from_pretrained(path, local_files_only=True)
assert isinstance(siglip_model, SiglipVisionModel)
siglip_processor = SiglipImageProcessor.from_pretrained(path, local_files_only=True)
assert isinstance(siglip_processor, SiglipImageProcessor)
return cls(siglip_processor, siglip_model)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._siglip_model.to(device=device, dtype=dtype)
def encode_image(self, x: Image.Image, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
imgs = self._siglip_processor.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
encoded_x = self._siglip_model(**imgs.to(device=device, dtype=dtype)).last_hidden_state
return encoded_x