mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Get FLUX Fill working. Note: To use FLUX Fill, set guidance to ~30.
This commit is contained in:
committed by
psychedelicious
parent
f13a07ba6a
commit
5ea3ec5cc8
@@ -291,7 +291,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
img_cond = pack(img_cond) if img_cond is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
@@ -663,13 +662,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return controlnet_extensions
|
||||
|
||||
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
|
||||
if self.control_lora is None:
|
||||
return None
|
||||
|
||||
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor:
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")
|
||||
|
||||
assert self.control_lora is not None
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.control_lora.img.image_name)
|
||||
cond_img = cond_img.convert("RGB")
|
||||
@@ -684,23 +682,24 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
return pack(img_cond)
|
||||
|
||||
def _prep_flux_fill_img_cond(
|
||||
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor | None:
|
||||
) -> torch.Tensor:
|
||||
"""Prepare the FLUX Fill conditioning.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
|
||||
"""
|
||||
if self.fill_conditioning is None:
|
||||
return None
|
||||
|
||||
# TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("controlnet_vae must be set when using a FLUX Fill conditioning.")
|
||||
|
||||
assert self.fill_conditioning is not None
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
|
||||
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
|
||||
@@ -711,9 +710,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
# Load the mask and resize it to the target image size.
|
||||
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
|
||||
# We expect mask to be a bool tensor with shape [1, H, W].
|
||||
assert mask.dtype == torch.bool
|
||||
assert mask.dim() == 3
|
||||
assert mask.shape[0] == 1
|
||||
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
mask = einops.rearrange(mask, "h w -> 1 1 h w")
|
||||
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
|
||||
|
||||
# Prepare image conditioning.
|
||||
cond_img = cond_img * (1 - mask)
|
||||
@@ -724,12 +727,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Prepare mask conditioning.
|
||||
mask = mask[:, 0, :, :]
|
||||
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
|
||||
mask = einops.rearrange(
|
||||
mask,
|
||||
"b (h ph) (w pw) -> b (ph pw) h w",
|
||||
ph=8,
|
||||
pw=8,
|
||||
)
|
||||
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
|
||||
mask = pack(mask)
|
||||
|
||||
# Merge image and mask conditioning.
|
||||
|
||||
@@ -20,6 +20,7 @@ class ModelSpec:
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-dev-fill": 512,
|
||||
"flux-schnell": 256,
|
||||
}
|
||||
|
||||
@@ -68,4 +69,19 @@ params = {
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
"flux-dev-fill": FluxParams(
|
||||
in_channels=384,
|
||||
out_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -417,20 +417,22 @@ class ModelProbe(object):
|
||||
# TODO: Decide between dev/schnell
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
|
||||
# HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model
|
||||
# loading. When FLUX support was first added, it was decided that this was the easiest way to support
|
||||
# the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in
|
||||
# the future.
|
||||
if (
|
||||
"guidance_in.out_layer.weight" in state_dict
|
||||
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
|
||||
):
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-dev"
|
||||
if variant_type == ModelVariantType.Normal:
|
||||
config_file = "flux-dev"
|
||||
elif variant_type == ModelVariantType.Inpaint:
|
||||
config_file = "flux-dev-fill"
|
||||
else:
|
||||
raise ValueError(f"Unexpected FLUX variant type: {variant_type}")
|
||||
else:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the discriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
|
||||
Reference in New Issue
Block a user