Get FLUX Fill working. Note: To use FLUX Fill, set guidance to ~30.

This commit is contained in:
Ryan Dick
2025-03-13 16:38:18 +00:00
committed by psychedelicious
parent f13a07ba6a
commit 5ea3ec5cc8
3 changed files with 42 additions and 26 deletions

View File

@@ -291,7 +291,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
img_cond = pack(img_cond) if img_cond is not None else None
noise = pack(noise)
x = pack(x)
@@ -663,13 +662,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return controlnet_extensions
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
if self.control_lora is None:
return None
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor:
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")
assert self.control_lora is not None
# Load the conditioning image and resize it to the target image size.
cond_img = context.images.get_pil(self.control_lora.img.image_name)
cond_img = cond_img.convert("RGB")
@@ -684,23 +682,24 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
vae_info = context.models.load(self.controlnet_vae.vae)
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
return pack(img_cond)
def _prep_flux_fill_img_cond(
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
) -> torch.Tensor | None:
) -> torch.Tensor:
"""Prepare the FLUX Fill conditioning.
This logic is based on:
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
"""
if self.fill_conditioning is None:
return None
# TODO(ryand): We should probably rename controlnet_vae. It's used for more than just ControlNets.
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a FLUX Fill conditioning.")
assert self.fill_conditioning is not None
# Load the conditioning image and resize it to the target image size.
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
@@ -711,9 +710,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Load the mask and resize it to the target image size.
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
# We expect mask to be a bool tensor with shape [1, H, W].
assert mask.dtype == torch.bool
assert mask.dim() == 3
assert mask.shape[0] == 1
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
mask = mask.to(device=device, dtype=dtype)
mask = einops.rearrange(mask, "h w -> 1 1 h w")
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
# Prepare image conditioning.
cond_img = cond_img * (1 - mask)
@@ -724,12 +727,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Prepare mask conditioning.
mask = mask[:, 0, :, :]
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
mask = einops.rearrange(
mask,
"b (h ph) (w pw) -> b (ph pw) h w",
ph=8,
pw=8,
)
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
mask = pack(mask)
# Merge image and mask conditioning.

View File

@@ -20,6 +20,7 @@ class ModelSpec:
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,
"flux-schnell": 256,
}
@@ -68,4 +69,19 @@ params = {
qkv_bias=True,
guidance_embed=False,
),
"flux-dev-fill": FluxParams(
in_channels=384,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
}

View File

@@ -417,20 +417,22 @@ class ModelProbe(object):
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
# HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model
# loading. When FLUX support was first added, it was decided that this was the easiest way to support
# the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in
# the future.
if (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
):
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the discriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev"
if variant_type == ModelVariantType.Normal:
config_file = "flux-dev"
elif variant_type == ModelVariantType.Inpaint:
config_file = "flux-dev-fill"
else:
raise ValueError(f"Unexpected FLUX variant type: {variant_type}")
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the discriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]