diff --git a/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py b/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py index 0e99fac4a1..b5f24bc554 100644 --- a/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py +++ b/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py @@ -5,17 +5,11 @@ import torch from einops import rearrange +from invokeai.backend.flux.controlnet.zero_module import zero_module from invokeai.backend.flux.model import FluxParams from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding -def _zero_module(module: torch.nn.Module) -> torch.nn.Module: - """Initialize the parameters of a module to zero.""" - for p in module.parameters(): - torch.nn.init.zeros_(p) - return module - - class XLabsControlNetFlux(torch.nn.Module): """A ControlNet model for FLUX. @@ -63,7 +57,7 @@ class XLabsControlNetFlux(torch.nn.Module): self.controlnet_blocks = torch.nn.ModuleList([]) for _ in range(controlnet_depth): controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size) - controlnet_block = _zero_module(controlnet_block) + controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True) self.input_hint_block = torch.nn.Sequential( @@ -81,7 +75,7 @@ class XLabsControlNetFlux(torch.nn.Module): torch.nn.SiLU(), torch.nn.Conv2d(16, 16, 3, padding=1, stride=2), torch.nn.SiLU(), - _zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)), + zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)), ) def forward( diff --git a/invokeai/backend/flux/controlnet/zero_module.py b/invokeai/backend/flux/controlnet/zero_module.py new file mode 100644 index 0000000000..53a21861a9 --- /dev/null +++ b/invokeai/backend/flux/controlnet/zero_module.py @@ -0,0 +1,12 @@ +from typing import TypeVar + +import torch + +T = TypeVar("T", bound=torch.nn.Module) + + +def zero_module(module: T) -> T: + """Initialize the parameters of a module to zero.""" + for p in module.parameters(): + torch.nn.init.zeros_(p) + return module