mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Remove gradient checkpointing from ControlNetFlux.
This commit is contained in:
@@ -21,8 +21,6 @@ class ControlNetFlux(nn.Module):
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, params: FluxParams, controlnet_depth=2):
|
||||
super().__init__()
|
||||
|
||||
@@ -64,7 +62,6 @@ class ControlNetFlux(nn.Module):
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.gradient_checkpointing = False
|
||||
self.input_hint_block = nn.Sequential(
|
||||
nn.Conv2d(3, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
@@ -83,10 +80,6 @@ class ControlNetFlux(nn.Module):
|
||||
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
@@ -121,27 +114,7 @@ class ControlNetFlux(nn.Module):
|
||||
block_res_samples = ()
|
||||
|
||||
for block in self.double_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
img,
|
||||
txt,
|
||||
vec,
|
||||
pe,
|
||||
)
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
block_res_samples = block_res_samples + (img,)
|
||||
|
||||
controlnet_block_res_samples = ()
|
||||
|
||||
Reference in New Issue
Block a user