Fix type errors and improve docs for ControlNetFlux.

This commit is contained in:
Ryan Dick
2024-10-02 18:02:58 +00:00
committed by Kent Keirsey
parent 69c0d7dcc9
commit 0b84f567f1

View File

@@ -10,18 +10,23 @@ from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
def zero_module(module):
def _zero_module(module: torch.nn.Module) -> torch.nn.Module:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
nn.init.zeros_(p)
return module
class ControlNetFlux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""A ControlNet model for FLUX.
The architecture is very similar to the base FLUX model, with the following differences:
- A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
- There is a set of `controlnet_blocks` that are applied to the output of each double_block.
"""
def __init__(self, params: FluxParams, controlnet_depth=2):
def __init__(self, params: FluxParams, controlnet_depth: int = 2):
super().__init__()
self.params = params
@@ -55,11 +60,11 @@ class ControlNetFlux(nn.Module):
]
)
# add ControlNet blocks
# Add ControlNet blocks.
self.controlnet_blocks = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
controlnet_block = _zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.input_hint_block = nn.Sequential(
@@ -77,7 +82,7 @@ class ControlNetFlux(nn.Module):
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
_zero_module(nn.Conv2d(16, 16, 3, padding=1)),
)
def forward(
@@ -90,7 +95,7 @@ class ControlNetFlux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> Tensor:
) -> list[Tensor]:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -111,15 +116,15 @@ class ControlNetFlux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
block_res_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
block_res_samples.append(img)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=False):
controlnet_block_res_samples: list[torch.Tensor] = []
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
controlnet_block_res_samples.append(block_res_sample)
return controlnet_block_res_samples