mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 11:38:01 -05:00
Add unit test to test the full flow of loading an InstantX ControlNet from a state dict.
This commit is contained in:
@@ -89,11 +89,11 @@ class DiffusersControlNetFlux(torch.nn.Module):
|
||||
# The following modules are specific to the ControlNet model.
|
||||
# -----------------------------------------------------------
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
for _ in range(len(self.double_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
for _ in range(len(self.single_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
|
||||
|
||||
self.is_union = False
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.controlnet.diffusers_controlnet_flux import DiffusersControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
convert_diffusers_instantx_state_dict_to_bfl_format,
|
||||
infer_flux_params_from_state_dict,
|
||||
@@ -46,7 +47,7 @@ def test_convert_diffusers_instantx_state_dict_to_bfl_format():
|
||||
|
||||
|
||||
def test_infer_flux_params_from_state_dict():
|
||||
# Construct a dummy state_dict with tensor of the correct shape on the meta device.
|
||||
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
|
||||
with torch.device("meta"):
|
||||
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
|
||||
|
||||
@@ -68,7 +69,7 @@ def test_infer_flux_params_from_state_dict():
|
||||
|
||||
|
||||
def test_infer_instantx_num_control_modes_from_state_dict():
|
||||
# Construct a dummy state_dict with tensor of the correct shape on the meta device.
|
||||
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
|
||||
with torch.device("meta"):
|
||||
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
|
||||
|
||||
@@ -76,3 +77,23 @@ def test_infer_instantx_num_control_modes_from_state_dict():
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
assert num_control_modes == instantx_config["num_mode"]
|
||||
|
||||
|
||||
def test_load_instantx_from_state_dict():
|
||||
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
|
||||
with torch.device("meta"):
|
||||
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
|
||||
|
||||
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
|
||||
flux_params = infer_flux_params_from_state_dict(sd)
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
with torch.device("meta"):
|
||||
model = DiffusersControlNetFlux(flux_params, num_control_modes)
|
||||
|
||||
model_sd = model.state_dict()
|
||||
|
||||
assert set(model_sd.keys()) == set(sd.keys())
|
||||
for key, tensor in model_sd.items():
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
assert tensor.shape == sd[key].shape
|
||||
|
||||
Reference in New Issue
Block a user