mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Remove FluxMultiControlNetModel
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
# 2. We need to sort out https://github.com/invoke-ai/InvokeAI/pull/6740 before we can bump the diffusers package.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -308,121 +308,3 @@ class DiffusersControlNetFlux(ModelMixin, ConfigMixin):
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
|
||||
|
||||
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `FluxControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[FluxControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`FluxControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(
|
||||
control_block_samples, block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
else:
|
||||
for i, (image, mode, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(
|
||||
control_block_samples, block_samples, strict=False
|
||||
)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
|
||||
Reference in New Issue
Block a user