Mirror of https://github.com/invoke-ai/InvokeAI.git

Compare commits: controlnet ... ryan/flux- (10 commits)

a7ded2e1f8
acfed395fb
cf32c7b370
9d897b46cb
9ad64bb45b
4114f647dd
3632a6cb18
f90e05ca7c
cfc7148444
e5b578eed2
@@ -6,6 +6,7 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,

@@ -19,6 +20,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
from invokeai.backend.flux.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux

@@ -44,7 +47,7 @@ from invokeai.backend.util.devices import TorchDevice
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="3.0.0",
    version="3.1.0",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

@@ -87,6 +90,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    controlnet: ControlField | list[ControlField] | None = InputField(
        default=None, input=Input.Connection, description="ControlNet models."
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:

@@ -167,8 +173,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

        inpaint_mask = self._prep_inpaint_mask(context, x)

        b, _c, h, w = x.shape
        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
        b, _c, latent_h, latent_w = x.shape
        img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

@@ -231,6 +237,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
            else:
                raise ValueError(f"Unsupported model format: {config.format}")

            # Prepare ControlNet extensions.
            controlnet_extensions = self._prep_controlnet_extensions(
                context=context,
                exit_stack=exit_stack,
                latent_height=latent_h,
                latent_width=latent_w,
                dtype=inference_dtype,
                device=x.device,
            )

            x = denoise(
                model=transformer,
                img=x,

@@ -242,6 +258,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                inpaint_extension=inpaint_extension,
                controlnet_extensions=controlnet_extensions,
            )

        x = unpack(x.float(), self.height, self.width)

@@ -288,6 +305,50 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        # `latents`.
        return mask.expand_as(latents)

    def _prep_controlnet_extensions(
        self,
        context: InvocationContext,
        exit_stack: ExitStack,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[ControlNetExtension] | None:
        # Normalize the controlnet input to list[ControlField].
        controlnets: list[ControlField]
        if self.controlnet is None:
            return None
        elif isinstance(self.controlnet, ControlField):
            controlnets = [self.controlnet]
        elif isinstance(self.controlnet, list):
            controlnets = self.controlnet
        else:
            raise ValueError(f"Unsupported controlnet type: {type(self.controlnet)}")

        controlnet_extensions: list[ControlNetExtension] = []
        for controlnet in controlnets:
            model = exit_stack.enter_context(context.models.load(controlnet.control_model))
            assert isinstance(model, ControlNetFlux)
            image = context.images.get_pil(controlnet.image.image_name)

            controlnet_extensions.append(
                ControlNetExtension.from_controlnet_image(
                    model=model,
                    controlnet_image=image,
                    latent_height=latent_height,
                    latent_width=latent_width,
                    dtype=dtype,
                    device=device,
                    control_mode=controlnet.control_mode,
                    resize_mode=controlnet.resize_mode,
                    weight=controlnet.control_weight,
                    begin_step_percent=controlnet.begin_step_percent,
                    end_step_percent=controlnet.end_step_percent,
                )
            )

        return controlnet_extensions

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
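Illustrative only, not part of the diff: the new `controlnet` field accepts a single `ControlField`, a list of them, or nothing, and `_prep_controlnet_extensions` always normalizes this to a list before building extensions. A stripped-down sketch of that normalization, with a generic placeholder type standing in for `ControlField`:

    from typing import TypeVar

    T = TypeVar("T")  # placeholder for ControlField


    def normalize_to_list(value: T | list[T] | None) -> list[T] | None:
        """Return None, or always a list, regardless of how the input was connected."""
        if value is None:
            return None
        if isinstance(value, list):
            return value
        return [value]


    assert normalize_to_list(None) is None
    assert normalize_to_list("canny") == ["canny"]
    assert normalize_to_list(["canny", "depth"]) == ["canny", "depth"]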
invokeai/backend/flux/controlnet/__init__.py (new file, 0 lines)

invokeai/backend/flux/controlnet/controlnet_flux.py (new file, 130 lines)
@@ -0,0 +1,130 @@
# This file was initially based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py

import torch
from einops import rearrange
from torch import Tensor, nn

from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding


def _zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """Initialize the parameters of a module to zero."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class ControlNetFlux(nn.Module):
    """A ControlNet model for FLUX.

    The architecture is very similar to the base FLUX model, with the following differences:
    - A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
      In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
    - There is a set of `controlnet_blocks` that are applied to the output of each double_block.
    """

    def __init__(self, params: FluxParams, controlnet_depth: int = 2):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(controlnet_depth)
            ]
        )

        # Add ControlNet blocks.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(controlnet_depth):
            controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
            controlnet_block = _zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
        self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            _zero_module(nn.Conv2d(16, 16, 3, padding=1)),
        )

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> list[Tensor]:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        controlnet_cond = self.input_hint_block(controlnet_cond)
        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples: list[torch.Tensor] = []

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            block_res_samples.append(img)

        controlnet_block_res_samples: list[torch.Tensor] = []
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples.append(block_res_sample)

        return controlnet_block_res_samples
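As an illustration of the zero-initialization pattern used for `controlnet_blocks` and the final hint convolution (a minimal sketch, not part of the diff): a zero-initialized projection returns zeros for any input, so a freshly initialized ControlNetFlux contributes nothing to the base model's activations until it is trained.

    import torch
    from torch import nn


    def _zero_module(module: nn.Module) -> nn.Module:
        """Initialize the parameters of a module to zero (same helper as above)."""
        for p in module.parameters():
            nn.init.zeros_(p)
        return module


    # A zero-initialized Linear maps every input to zeros, so the residuals the
    # ControlNet emits at initialization are exactly zero and the base model's
    # behavior is unchanged at the start of training.
    proj = _zero_module(nn.Linear(64, 64))
    tokens = torch.randn(2, 16, 64)  # (batch, sequence, hidden) sizes are made up
    assert torch.all(proj(tokens) == 0)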
invokeai/backend/flux/controlnet_extension.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import math
from typing import List, Union

import torch
from PIL.Image import Image

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux


class ControlNetExtension:
    def __init__(
        self,
        model: ControlNetFlux,
        controlnet_cond: torch.Tensor,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        self._model = model
        # _controlnet_cond is the control image passed to the ControlNet model.
        # Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
        self._controlnet_cond = controlnet_cond

        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent

    @classmethod
    def from_controlnet_image(
        cls,
        model: ControlNetFlux,
        controlnet_image: Image,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        control_mode: CONTROLNET_MODE_VALUES,
        resize_mode: CONTROLNET_RESIZE_VALUES,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        controlnet_cond = prepare_control_image(
            image=controlnet_image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=device,
            dtype=dtype,
            control_mode=control_mode,
            resize_mode=resize_mode,
        )

        # Map pixel values from [0, 1] to [-1, 1].
        controlnet_cond = controlnet_cond * 2 - 1

        return cls(
            model=model,
            controlnet_cond=controlnet_cond,
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )

    def run_controlnet(
        self,
        timestep_index: int,
        total_num_timesteps: int,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        y: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: torch.Tensor | None,
    ) -> list[torch.Tensor] | None:
        first_step = math.floor(self._begin_step_percent * total_num_timesteps)
        last_step = math.ceil(self._end_step_percent * total_num_timesteps)
        if timestep_index < first_step or timestep_index > last_step:
            return
        weight = self._weight

        controlnet_block_res_samples = self._model(
            img=img,
            img_ids=img_ids,
            controlnet_cond=self._controlnet_cond,
            txt=txt,
            txt_ids=txt_ids,
            timesteps=timesteps,
            y=y,
            guidance=guidance,
        )

        # Apply weight to the residuals.
        for block_res_sample in controlnet_block_res_samples:
            block_res_sample *= weight

        return controlnet_block_res_samples
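For illustration, a minimal sketch of the step gating in `run_controlnet` as a standalone helper (the helper name and the concrete numbers are assumptions, not taken from the diff): `begin_step_percent` and `end_step_percent` translate into an inclusive range of denoising steps during which the ControlNet contributes.

    import math


    def controlnet_is_active(timestep_index: int, total_num_timesteps: int,
                             begin_step_percent: float, end_step_percent: float) -> bool:
        # Mirrors the gating above: the ControlNet only runs between these two
        # step indices, inclusive.
        first_step = math.floor(begin_step_percent * total_num_timesteps)
        last_step = math.ceil(end_step_percent * total_num_timesteps)
        return first_step <= timestep_index <= last_step


    # With 20 steps and a 0.25-0.75 window, steps 5 through 15 are active.
    active = [i for i in range(20) if controlnet_is_active(i, 20, 0.25, 0.75)]
    assert active == list(range(5, 16))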
@@ -3,6 +3,7 @@ from typing import Callable
import torch
from tqdm import tqdm

from invokeai.backend.flux.controlnet_extension import ControlNetExtension
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

@@ -21,6 +22,7 @@ def denoise(
    step_callback: Callable[[PipelineIntermediateState], None],
    guidance: float,
    inpaint_extension: InpaintExtension | None,
    controlnet_extensions: list[ControlNetExtension] | None,
):
    # step 0 is the initial state
    total_steps = len(timesteps) - 1

@@ -38,6 +40,25 @@ def denoise(
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # Run ControlNet models.
        # controlnet_block_residuals[i][j] is the residual of the j-th block of the i-th ControlNet model.
        controlnet_block_residuals: list[list[torch.Tensor] | None] = []
        for controlnet_extension in controlnet_extensions or []:
            controlnet_block_residuals.append(
                controlnet_extension.run_controlnet(
                    timestep_index=step - 1,
                    total_num_timesteps=total_steps,
                    img=img,
                    img_ids=img_ids,
                    txt=txt,
                    txt_ids=txt_ids,
                    y=vec,
                    timesteps=t_vec,
                    guidance=guidance_vec,
                )
            )

        pred = model(
            img=img,
            img_ids=img_ids,

@@ -46,6 +67,7 @@ def denoise(
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            controlnet_block_residuals=controlnet_block_residuals,
        )

        preview_img = img - t_curr * pred
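To make the shape of `controlnet_block_residuals` concrete, a hypothetical example with invented sizes (not code from the diff): each outer entry holds one ControlNet's per-block residuals, or `None` if that ControlNet is outside its step window at the current step.

    import torch

    # Two ControlNets of depth 2 over a packed latent sequence of 4 tokens with
    # hidden size 8 (all sizes invented for illustration).
    residuals_from_cn0 = [torch.zeros(1, 4, 8), torch.zeros(1, 4, 8)]
    residuals_from_cn1 = None  # outside its begin/end step window at this step
    controlnet_block_residuals = [residuals_from_cn0, residuals_from_cn1]

    # controlnet_block_residuals[i][j] is the residual of the j-th double block
    # produced by the i-th ControlNet model.
    assert controlnet_block_residuals[0][1].shape == (1, 4, 8)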
@@ -88,6 +88,7 @@ class Flux(nn.Module):
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        controlnet_block_residuals: list[list[Tensor] | None] | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

@@ -105,9 +106,15 @@ class Flux(nn.Module):
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
        for block_index, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

            # Apply ControlNet residuals.
            if controlnet_block_residuals is not None:
                for single_controlnet_block_residuals in controlnet_block_residuals:
                    if single_controlnet_block_residuals:
                        img += single_controlnet_block_residuals[block_index % len(single_controlnet_block_residuals)]

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
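A worked example of the modulo indexing above (a sketch; the double-block count is an assumption about flux-dev, not stated in this diff): with a depth-2 ControlNet, each base double block reuses one of the two residuals cyclically.

    controlnet_depth = 2      # depth of the ControlNetFlux above
    num_double_blocks = 19    # assumed double-block count for the base model

    # block_index % len(residuals): residual 0, 1, 0, 1, ... across the base blocks.
    mapping = [block_index % controlnet_depth for block_index in range(num_double_blocks)]
    assert mapping[:5] == [0, 1, 0, 1, 0]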
@@ -8,17 +8,36 @@ from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
)
from invokeai.backend.model_manager.config import (
    BaseModelType,
    ControlNetCheckpointConfig,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
class ControlNetLoader(GenericDiffusersLoader):
    """Class to load ControlNet models."""
@@ -10,6 +10,7 @@ from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.controlnet_flux import ControlNetFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params

@@ -24,6 +25,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    CLIPEmbedDiffusersConfig,
    ControlNetCheckpointConfig,
    MainBnbQuantized4bCheckpointConfig,
    MainCheckpointConfig,
    MainGGUFCheckpointConfig,

@@ -293,3 +295,24 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
        sd = convert_bundle_to_flux_transformer_checkpoint(sd)
        model.load_state_dict(sd, assign=True)
        return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class FluxControlnetModel(ModelLoader):
    """Class to load FLUX ControlNet models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        assert isinstance(config, ControlNetCheckpointConfig)
        model_path = Path(config.path)

        with accelerate.init_empty_weights():
            # HACK(ryand): Is it safe to assume dev here?
            model = ControlNetFlux(params["flux-dev"])

        sd = load_file(model_path)
        model.load_state_dict(sd, assign=True)
        return model
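A toy stand-in for the load pattern used above (a sketch with made-up sizes, not part of the diff): the module is first built without allocating real storage, then `load_state_dict(..., assign=True)` replaces the placeholder parameters with the tensors read from the checkpoint.

    import torch
    from torch import nn

    # Build the module on the "meta" device so no weight memory is allocated,
    # analogous to accelerate.init_empty_weights() in the loader above.
    with torch.device("meta"):
        model = nn.Linear(4, 4)

    sd = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
    # assign=True swaps the meta parameters for the checkpoint tensors instead
    # of trying to copy into storage that was never allocated.
    model.load_state_dict(sd, assign=True)
    assert model.weight.device.type == "cpu"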
@@ -255,7 +255,19 @@ class ModelProbe(object):
        # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
        elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
            return ModelType.LoRA
        elif key.startswith(("controlnet", "control_model", "input_blocks")):
        elif key.startswith(
            (
                "controlnet",
                "control_model",
                "input_blocks",
                # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
                # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
                # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
                # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
                # delicate.
                "controlnet_blocks",
            )
        ):
            return ModelType.ControlNet
        elif key.startswith(("image_proj.", "ip_adapter.")):
            return ModelType.IPAdapter

@@ -623,6 +635,12 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint

        if "double_blocks.0.img_attn.qkv.weight" in checkpoint:
            # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
            # get_format()?
            return BaseModelType.Flux

        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "controlnet_mid_block.bias",
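A minimal sketch combining the two checks above (hypothetical helper name; the real probe walks keys in a fixed order and handles many more cases): key prefixes mark the checkpoint as a ControlNet, and the presence of FLUX double-stream attention weights marks the base model as FLUX.

    def looks_like_flux_controlnet(state_dict_keys: set[str]) -> bool:
        """Hypothetical helper: classify a checkpoint from its state-dict keys."""
        is_controlnet = any(k.startswith("controlnet_blocks") for k in state_dict_keys)
        is_flux_base = "double_blocks.0.img_attn.qkv.weight" in state_dict_keys
        return is_controlnet and is_flux_base


    keys = {"controlnet_blocks.0.weight", "double_blocks.0.img_attn.qkv.weight"}
    assert looks_like_flux_controlnet(keys)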
@@ -0,0 +1,91 @@
# State dict keys for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
state_dict_keys = [
    "controlnet_blocks.0.bias",
    "controlnet_blocks.0.weight",
    "controlnet_blocks.1.bias",
    "controlnet_blocks.1.weight",
    "double_blocks.0.img_attn.norm.key_norm.scale",
    "double_blocks.0.img_attn.norm.query_norm.scale",
    "double_blocks.0.img_attn.proj.bias",
    "double_blocks.0.img_attn.proj.weight",
    "double_blocks.0.img_attn.qkv.bias",
    "double_blocks.0.img_attn.qkv.weight",
    "double_blocks.0.img_mlp.0.bias",
    "double_blocks.0.img_mlp.0.weight",
    "double_blocks.0.img_mlp.2.bias",
    "double_blocks.0.img_mlp.2.weight",
    "double_blocks.0.img_mod.lin.bias",
    "double_blocks.0.img_mod.lin.weight",
    "double_blocks.0.txt_attn.norm.key_norm.scale",
    "double_blocks.0.txt_attn.norm.query_norm.scale",
    "double_blocks.0.txt_attn.proj.bias",
    "double_blocks.0.txt_attn.proj.weight",
    "double_blocks.0.txt_attn.qkv.bias",
    "double_blocks.0.txt_attn.qkv.weight",
    "double_blocks.0.txt_mlp.0.bias",
    "double_blocks.0.txt_mlp.0.weight",
    "double_blocks.0.txt_mlp.2.bias",
    "double_blocks.0.txt_mlp.2.weight",
    "double_blocks.0.txt_mod.lin.bias",
    "double_blocks.0.txt_mod.lin.weight",
    "double_blocks.1.img_attn.norm.key_norm.scale",
    "double_blocks.1.img_attn.norm.query_norm.scale",
    "double_blocks.1.img_attn.proj.bias",
    "double_blocks.1.img_attn.proj.weight",
    "double_blocks.1.img_attn.qkv.bias",
    "double_blocks.1.img_attn.qkv.weight",
    "double_blocks.1.img_mlp.0.bias",
    "double_blocks.1.img_mlp.0.weight",
    "double_blocks.1.img_mlp.2.bias",
    "double_blocks.1.img_mlp.2.weight",
    "double_blocks.1.img_mod.lin.bias",
    "double_blocks.1.img_mod.lin.weight",
    "double_blocks.1.txt_attn.norm.key_norm.scale",
    "double_blocks.1.txt_attn.norm.query_norm.scale",
    "double_blocks.1.txt_attn.proj.bias",
    "double_blocks.1.txt_attn.proj.weight",
    "double_blocks.1.txt_attn.qkv.bias",
    "double_blocks.1.txt_attn.qkv.weight",
    "double_blocks.1.txt_mlp.0.bias",
    "double_blocks.1.txt_mlp.0.weight",
    "double_blocks.1.txt_mlp.2.bias",
    "double_blocks.1.txt_mlp.2.weight",
    "double_blocks.1.txt_mod.lin.bias",
    "double_blocks.1.txt_mod.lin.weight",
    "guidance_in.in_layer.bias",
    "guidance_in.in_layer.weight",
    "guidance_in.out_layer.bias",
    "guidance_in.out_layer.weight",
    "img_in.bias",
    "img_in.weight",
    "input_hint_block.0.bias",
    "input_hint_block.0.weight",
    "input_hint_block.10.bias",
    "input_hint_block.10.weight",
    "input_hint_block.12.bias",
    "input_hint_block.12.weight",
    "input_hint_block.14.bias",
    "input_hint_block.14.weight",
    "input_hint_block.2.bias",
    "input_hint_block.2.weight",
    "input_hint_block.4.bias",
    "input_hint_block.4.weight",
    "input_hint_block.6.bias",
    "input_hint_block.6.weight",
    "input_hint_block.8.bias",
    "input_hint_block.8.weight",
    "pos_embed_input.bias",
    "pos_embed_input.weight",
    "time_in.in_layer.bias",
    "time_in.in_layer.weight",
    "time_in.out_layer.bias",
    "time_in.out_layer.weight",
    "txt_in.bias",
    "txt_in.weight",
    "vector_in.in_layer.bias",
    "vector_in.in_layer.weight",
    "vector_in.out_layer.bias",
    "vector_in.out_layer.weight",
]