diff --git a/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py b/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py
index 00876394ee..0e99fac4a1 100644
--- a/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py
+++ b/invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py
@@ -4,7 +4,6 @@
 import torch
 from einops import rearrange
-from torch import Tensor, nn
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
 
@@ -13,11 +12,11 @@ from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLP
 def _zero_module(module: torch.nn.Module) -> torch.nn.Module:
     """Initialize the parameters of a module to zero."""
     for p in module.parameters():
-        nn.init.zeros_(p)
+        torch.nn.init.zeros_(p)
     return module
 
 
-class XLabsControlNetFlux(nn.Module):
+class XLabsControlNetFlux(torch.nn.Module):
     """A ControlNet model for FLUX.
 
     The architecture is very similar to the base FLUX model, with the following differences:
@@ -40,15 +39,15 @@ class XLabsControlNetFlux(nn.Module):
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
-        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.img_in = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
         self.guidance_in = (
-            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else torch.nn.Identity()
         )
-        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+        self.txt_in = torch.nn.Linear(params.context_in_dim, self.hidden_size)
 
-        self.double_blocks = nn.ModuleList(
+        self.double_blocks = torch.nn.ModuleList(
             [
                 DoubleStreamBlock(
                     self.hidden_size,
@@ -61,41 +60,41 @@ class XLabsControlNetFlux(nn.Module):
         )
 
         # Add ControlNet blocks.
-        self.controlnet_blocks = nn.ModuleList([])
+        self.controlnet_blocks = torch.nn.ModuleList([])
         for _ in range(controlnet_depth):
-            controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
+            controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
             controlnet_block = _zero_module(controlnet_block)
             self.controlnet_blocks.append(controlnet_block)
-        self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
-        self.input_hint_block = nn.Sequential(
-            nn.Conv2d(3, 16, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1, stride=2),
-            nn.SiLU(),
-            _zero_module(nn.Conv2d(16, 16, 3, padding=1)),
+        self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.input_hint_block = torch.nn.Sequential(
+            torch.nn.Conv2d(3, 16, 3, padding=1),
+            torch.nn.SiLU(),
+            torch.nn.Conv2d(16, 16, 3, padding=1),
+            torch.nn.SiLU(),
+            torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
+            torch.nn.SiLU(),
+            torch.nn.Conv2d(16, 16, 3, padding=1),
+            torch.nn.SiLU(),
+            torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
+            torch.nn.SiLU(),
+            torch.nn.Conv2d(16, 16, 3, padding=1),
+            torch.nn.SiLU(),
+            torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
+            torch.nn.SiLU(),
+            _zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
         )
 
     def forward(
         self,
-        img: Tensor,
-        img_ids: Tensor,
-        controlnet_cond: Tensor,
-        txt: Tensor,
-        txt_ids: Tensor,
-        timesteps: Tensor,
-        y: Tensor,
-        guidance: Tensor | None = None,
-    ) -> list[Tensor]:
+        img: torch.Tensor,
+        img_ids: torch.Tensor,
+        controlnet_cond: torch.Tensor,
+        txt: torch.Tensor,
+        txt_ids: torch.Tensor,
+        timesteps: torch.Tensor,
+        y: torch.Tensor,
+        guidance: torch.Tensor | None = None,
+    ) -> list[torch.Tensor]:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
 
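Note for reviewers: the `_zero_module` helper touched in this diff is the standard ControlNet zero-initialization trick. Because the `controlnet_blocks` projections and the final `Conv2d` of `input_hint_block` start with all-zero weights and biases, the ControlNet contributes nothing to the base FLUX model at the start of training, and its influence grows from zero as it is fine-tuned. Below is a minimal standalone sketch of why that works; it is illustrative only, not part of the diff, and the `block`/`x` names are invented for the example.

```python
import torch


def _zero_module(module: torch.nn.Module) -> torch.nn.Module:
    """Zero all parameters so the module initially outputs zero."""
    for p in module.parameters():
        torch.nn.init.zeros_(p)
    return module


# A zeroed Linear maps every input to zero, so a residual connection of the
# form `x + block(x)` behaves as the identity at initialization and the
# ControlNet's contribution is learned gradually during training.
block = _zero_module(torch.nn.Linear(8, 8))
x = torch.randn(2, 8)
assert torch.all(block(x) == 0)
```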