mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
WIP on updating FluxDenoise to support FLUX Fill.
This commit is contained in:
committed by
psychedelicious
parent
a913f0163d
commit
f13a07ba6a
@@ -267,8 +267,22 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if is_schnell and self.control_lora:
|
||||
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
|
||||
|
||||
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
# TODO(ryand): It's a bit confusing that we support inpainting via both FLUX Fill and masked image-to-image.
|
||||
# Think about ways to tidy this interface, or at least add clear error messages when incompatible inputs are
|
||||
# provided.
|
||||
|
||||
# Prepare the extra image conditioning tensor if either of the following are provided:
|
||||
# - FLUX structural control image
|
||||
# - FLUX Fill conditioning
|
||||
img_cond: torch.Tensor | None = None
|
||||
if self.control_lora is not None and self.fill_conditioning is not None:
|
||||
raise ValueError("Control LoRA and Fill conditioning cannot be used together.")
|
||||
elif self.control_lora is not None:
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
elif self.fill_conditioning is not None:
|
||||
img_cond = self._prep_flux_fill_img_cond(
|
||||
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
|
||||
)
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
@@ -672,6 +686,56 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
def _prep_flux_fill_img_cond(
|
||||
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor | None:
|
||||
"""Prepare the FLUX Fill conditioning.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
|
||||
"""
|
||||
if self.fill_conditioning is None:
|
||||
return None
|
||||
|
||||
# TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("controlnet_vae must be set when using a FLUX Fill conditioning.")
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
|
||||
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
|
||||
cond_img = np.array(cond_img)
|
||||
cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
|
||||
cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
|
||||
cond_img = cond_img.to(device=device, dtype=dtype)
|
||||
|
||||
# Load the mask and resize it to the target image size.
|
||||
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
|
||||
assert mask.dtype == torch.bool
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
mask = einops.rearrange(mask, "h w -> 1 1 h w")
|
||||
|
||||
# Prepare image conditioning.
|
||||
cond_img = cond_img * (1 - mask)
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
|
||||
cond_img = pack(cond_img)
|
||||
|
||||
# Prepare mask conditioning.
|
||||
mask = mask[:, 0, :, :]
|
||||
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
|
||||
mask = einops.rearrange(
|
||||
mask,
|
||||
"b (h ph) (w pw) -> b (ph pw) h w",
|
||||
ph=8,
|
||||
pw=8,
|
||||
)
|
||||
mask = pack(mask)
|
||||
|
||||
# Merge image and mask conditioning.
|
||||
img_cond = torch.cat((cond_img, mask), dim=-1)
|
||||
return img_cond
|
||||
|
||||
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
|
||||
if self.ip_adapter is None:
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user