Remove gradient checkpointing from ControlNetFlux.

This commit is contained in:
Ryan Dick
2024-10-02 17:49:53 +00:00
committed by Kent Keirsey
parent 5307248fcf
commit 69c0d7dcc9

View File

@@ -21,8 +21,6 @@ class ControlNetFlux(nn.Module):
Transformer model for flow matching on sequences.
"""
_supports_gradient_checkpointing = True
def __init__(self, params: FluxParams, controlnet_depth=2):
super().__init__()
@@ -64,7 +62,6 @@ class ControlNetFlux(nn.Module):
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
@@ -83,10 +80,6 @@ class ControlNetFlux(nn.Module):
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
img: Tensor,
@@ -121,27 +114,7 @@ class ControlNetFlux(nn.Module):
block_res_samples = ()
for block in self.double_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
img,
txt,
vec,
pe,
)
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
controlnet_block_res_samples = ()