diff --git a/invokeai/backend/flux/controlnet/diffusers_controlnet_flux.py b/invokeai/backend/flux/controlnet/diffusers_controlnet_flux.py index be5e7c0fc3..dc32d8cb79 100644 --- a/invokeai/backend/flux/controlnet/diffusers_controlnet_flux.py +++ b/invokeai/backend/flux/controlnet/diffusers_controlnet_flux.py @@ -89,11 +89,11 @@ class DiffusersControlNetFlux(torch.nn.Module): # The following modules are specific to the ControlNet model. # ----------------------------------------------------------- self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): + for _ in range(len(self.double_blocks)): self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size))) self.controlnet_single_blocks = nn.ModuleList([]) - for _ in range(len(self.single_transformer_blocks)): + for _ in range(len(self.single_blocks)): self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size))) self.is_union = False diff --git a/tests/backend/flux/controlnet/test_state_dict_utils.py b/tests/backend/flux/controlnet/test_state_dict_utils.py index 8dfbfa18e5..89248b4c2c 100644 --- a/tests/backend/flux/controlnet/test_state_dict_utils.py +++ b/tests/backend/flux/controlnet/test_state_dict_utils.py @@ -1,6 +1,7 @@ import pytest import torch +from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux from invokeai.backend.flux.controlnet.state_dict_utils import ( convert_diffusers_instantx_state_dict_to_bfl_format, infer_flux_params_from_state_dict, @@ -46,7 +47,7 @@ def test_convert_diffusers_instantx_state_dict_to_bfl_format(): def test_infer_flux_params_from_state_dict(): - # Construct a dummy state_dict with tensor of the correct shape on the meta device. + # Construct a dummy state_dict with tensors of the correct shape on the meta device. with torch.device("meta"): sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()} @@ -68,7 +69,7 @@ def test_infer_flux_params_from_state_dict(): def test_infer_instantx_num_control_modes_from_state_dict(): - # Construct a dummy state_dict with tensor of the correct shape on the meta device. + # Construct a dummy state_dict with tensors of the correct shape on the meta device. with torch.device("meta"): sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()} @@ -76,3 +77,23 @@ def test_infer_instantx_num_control_modes_from_state_dict(): num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) assert num_control_modes == instantx_config["num_mode"] + + +def test_load_instantx_from_state_dict(): + # Construct a dummy state_dict with tensors of the correct shape on the meta device. + with torch.device("meta"): + sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()} + + sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd) + flux_params = infer_flux_params_from_state_dict(sd) + num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) + + with torch.device("meta"): + model = DiffusersControlNetFlux(flux_params, num_control_modes) + + model_sd = model.state_dict() + + assert set(model_sd.keys()) == set(sd.keys()) + for key, tensor in model_sd.items(): + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == sd[key].shape