Rename DiffusersControlNetFlux -> InstantXControlNetFlux.

This commit is contained in:
Ryan Dick
2024-10-07 21:20:01 +00:00
committed by Kent Keirsey
parent d75ac56d00
commit 44c588d778
4 changed files with 11 additions and 11 deletions

View File

@@ -18,7 +18,7 @@ from invokeai.backend.flux.modules.layers import (
@dataclass
class DiffusersControlNetFluxOutput:
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
@@ -36,7 +36,7 @@ class DiffusersControlNetFluxOutput:
# - axes_dims_rope: axes_dim
class DiffusersControlNetFlux(torch.nn.Module):
class InstantXControlNetFlux(torch.nn.Module):
def __init__(self, params: FluxParams, num_control_modes: int | None = None):
"""
Args:
@@ -114,7 +114,7 @@ class DiffusersControlNetFlux(torch.nn.Module):
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> DiffusersControlNetFluxOutput:
) -> InstantXControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -168,7 +168,7 @@ class DiffusersControlNetFlux(torch.nn.Module):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples.append(single_block_sample)
return DiffusersControlNetFluxOutput(
return InstantXControlNetFluxOutput(
controlnet_block_samples=controlnet_double_block_samples or None,
controlnet_single_block_samples=controlnet_single_block_samples or None,
)

View File

@@ -173,11 +173,11 @@ def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
DiffusersControlNetFlux.
InstantXControlNetFlux model.
The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
implementation to DiffusersControlNetFlux to make it compatible with BFL-style models. This function converts the
original state dict to the format expected by DiffusersControlNetFlux.
implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
original state dict to the format expected by InstantXControlNetFlux.
"""
# Shallow copy sd so that we can pop keys from it without modifying the original.
sd = sd.copy()

View File

@@ -10,7 +10,7 @@ from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
@@ -341,7 +341,7 @@ class FluxControlnetModel(ModelLoader):
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with accelerate.init_empty_weights():
model = DiffusersControlNetFlux(flux_params, num_control_modes)
model = InstantXControlNetFlux(flux_params, num_control_modes)
model.load_state_dict(sd, assign=True)
return model

View File

@@ -1,7 +1,7 @@
import pytest
import torch
from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
@@ -89,7 +89,7 @@ def test_load_instantx_from_state_dict():
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with torch.device("meta"):
model = DiffusersControlNetFlux(flux_params, num_control_modes)
model = InstantXControlNetFlux(flux_params, num_control_modes)
model_sd = model.state_dict()