mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Rename DiffusersControlNetFlux -> InstantXControlNetFlux.
This commit is contained in:
@@ -18,7 +18,7 @@ from invokeai.backend.flux.modules.layers import (
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusersControlNetFluxOutput:
|
||||
class InstantXControlNetFluxOutput:
|
||||
controlnet_block_samples: list[torch.Tensor] | None
|
||||
controlnet_single_block_samples: list[torch.Tensor] | None
|
||||
|
||||
@@ -36,7 +36,7 @@ class DiffusersControlNetFluxOutput:
|
||||
# - axes_dims_rope: axes_dim
|
||||
|
||||
|
||||
class DiffusersControlNetFlux(torch.nn.Module):
|
||||
class InstantXControlNetFlux(torch.nn.Module):
|
||||
def __init__(self, params: FluxParams, num_control_modes: int | None = None):
|
||||
"""
|
||||
Args:
|
||||
@@ -114,7 +114,7 @@ class DiffusersControlNetFlux(torch.nn.Module):
|
||||
timesteps: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
guidance: torch.Tensor | None = None,
|
||||
) -> DiffusersControlNetFluxOutput:
|
||||
) -> InstantXControlNetFluxOutput:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
@@ -168,7 +168,7 @@ class DiffusersControlNetFlux(torch.nn.Module):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples.append(single_block_sample)
|
||||
|
||||
return DiffusersControlNetFluxOutput(
|
||||
return InstantXControlNetFluxOutput(
|
||||
controlnet_block_samples=controlnet_double_block_samples or None,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples or None,
|
||||
)
|
||||
@@ -173,11 +173,11 @@ def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
|
||||
|
||||
def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
|
||||
DiffusersControlNetFlux.
|
||||
InstantXControlNetFlux model.
|
||||
|
||||
The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
|
||||
implementation to DiffusersControlNetFlux to make it compatible with BFL-style models. This function converts the
|
||||
original state dict to the format expected by DiffusersControlNetFlux.
|
||||
implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
|
||||
original state dict to the format expected by InstantXControlNetFlux.
|
||||
"""
|
||||
# Shallow copy sd so that we can pop keys from it without modifying the original.
|
||||
sd = sd.copy()
|
||||
|
||||
@@ -10,7 +10,7 @@ from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
convert_diffusers_instantx_state_dict_to_bfl_format,
|
||||
infer_flux_params_from_state_dict,
|
||||
@@ -341,7 +341,7 @@ class FluxControlnetModel(ModelLoader):
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = DiffusersControlNetFlux(flux_params, num_control_modes)
|
||||
model = InstantXControlNetFlux(flux_params, num_control_modes)
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
convert_diffusers_instantx_state_dict_to_bfl_format,
|
||||
infer_flux_params_from_state_dict,
|
||||
@@ -89,7 +89,7 @@ def test_load_instantx_from_state_dict():
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
with torch.device("meta"):
|
||||
model = DiffusersControlNetFlux(flux_params, num_control_modes)
|
||||
model = InstantXControlNetFlux(flux_params, num_control_modes)
|
||||
|
||||
model_sd = model.state_dict()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user