Improve typing of zero_module().

This commit is contained in:
Ryan Dick
2024-10-04 18:22:10 +00:00
committed by Kent Keirsey
parent 83f4700f5a
commit 7562ea48dc
2 changed files with 15 additions and 9 deletions

View File

@@ -5,17 +5,11 @@
import torch
from einops import rearrange
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
def _zero_module(module: torch.nn.Module) -> torch.nn.Module:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
torch.nn.init.zeros_(p)
return module
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.
@@ -63,7 +57,7 @@ class XLabsControlNetFlux(torch.nn.Module):
self.controlnet_blocks = torch.nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = _zero_module(controlnet_block)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.input_hint_block = torch.nn.Sequential(
@@ -81,7 +75,7 @@ class XLabsControlNetFlux(torch.nn.Module):
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
_zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
)
def forward(

View File

@@ -0,0 +1,12 @@
from typing import TypeVar
import torch
T = TypeVar("T", bound=torch.nn.Module)
def zero_module(module: T) -> T:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
torch.nn.init.zeros_(p)
return module