WIP on updating FluxDenoise to support FLUX Fill.

This commit is contained in:
Ryan Dick
2025-03-12 19:41:30 +00:00
committed by psychedelicious
parent a913f0163d
commit f13a07ba6a

View File

@@ -267,8 +267,22 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if is_schnell and self.control_lora:
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
img_cond = self._prep_structural_control_img_cond(context)
# TODO(ryand): It's a bit confusing that we support inpainting via both FLUX Fill and masked image-to-image.
# Think about ways to tidy this interface, or at least add clear error messages when incompatible inputs are
# provided.
# Prepare the extra image conditioning tensor if either of the following are provided:
# - FLUX structural control image
# - FLUX Fill conditioning
img_cond: torch.Tensor | None = None
if self.control_lora is not None and self.fill_conditioning is not None:
raise ValueError("Control LoRA and Fill conditioning cannot be used together.")
elif self.control_lora is not None:
img_cond = self._prep_structural_control_img_cond(context)
elif self.fill_conditioning is not None:
img_cond = self._prep_flux_fill_img_cond(
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
)
inpaint_mask = self._prep_inpaint_mask(context, x)
@@ -672,6 +686,56 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.controlnet_vae.vae)
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
def _prep_flux_fill_img_cond(
    self, context: InvocationContext, device: torch.device, dtype: torch.dtype
) -> torch.Tensor | None:
    """Prepare the FLUX Fill image/mask conditioning tensor.

    This logic is based on:
    https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157

    Args:
        context: Invocation context used to load the conditioning image, mask tensor, and VAE.
        device: Device that the returned conditioning tensor is placed on.
        dtype: Dtype of the returned conditioning tensor.

    Returns:
        The packed (image + mask) conditioning tensor, or None if no Fill conditioning is set.

    Raises:
        ValueError: If `controlnet_vae` is not set, or if the stored mask is not a bool tensor.
    """
    if self.fill_conditioning is None:
        return None

    # TODO(ryand): We should probably rename controlnet_vae. It's used for more than just ControlNets.
    if not self.controlnet_vae:
        raise ValueError("controlnet_vae must be set when using a FLUX Fill conditioning.")

    # Load the conditioning image and resize it to the target image size.
    cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
    cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
    cond_img = np.array(cond_img)
    # Scale uint8 pixel values from [0, 255] to [-1, 1].
    cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
    cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
    cond_img = cond_img.to(device=device, dtype=dtype)

    # Load the mask and move it to the target device/dtype.
    # NOTE: the mask is assumed to already match the target (height, width) — confirm upstream.
    mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if mask.dtype != torch.bool:
        raise ValueError(f"Expected a bool mask tensor, but got dtype {mask.dtype}.")
    mask = mask.to(device=device, dtype=dtype)
    mask = einops.rearrange(mask, "h w -> 1 1 h w")

    # Zero out the masked region of the conditioning image, then VAE-encode and pack it.
    cond_img = cond_img * (1 - mask)
    vae_info = context.models.load(self.controlnet_vae.vae)
    cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
    cond_img = pack(cond_img)

    # Prepare mask conditioning: drop the singleton channel dim, then fold each 8x8
    # pixel patch into the channel dim. This yields ph*pw = 64 channels (not 16, as a
    # previous comment claimed) with spatial dims matching the 8x-downsampled VAE latent.
    mask = mask[:, 0, :, :]
    mask = einops.rearrange(
        mask,
        "b (h ph) (w pw) -> b (ph pw) h w",
        ph=8,
        pw=8,
    )
    mask = pack(mask)

    # Merge image and mask conditioning along the feature (last) dim.
    img_cond = torch.cat((cond_img, mask), dim=-1)
    return img_cond
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
if self.ip_adapter is None:
return []