From 4be3a337442cd3867fb3bc134cd3ffcae37c7d98 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 4 Oct 2024 19:19:56 +0000 Subject: [PATCH] Add utils for detecting XLabs ControlNet vs. InstantX ControlNet from state dict. --- .../flux/controlnet/state_dict_utils.py | 41 +++++++++++++++++++ .../instantx_flux_controlnet_state_dict.py | 2 +- .../flux/controlnet/test_state_dict_utils.py | 34 +++++++++++++++ .../xlabs_flux_controlnet_state_dict.py | 2 +- 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 invokeai/backend/flux/controlnet/state_dict_utils.py create mode 100644 tests/backend/flux/controlnet/test_state_dict_utils.py diff --git a/invokeai/backend/flux/controlnet/state_dict_utils.py b/invokeai/backend/flux/controlnet/state_dict_utils.py new file mode 100644 index 0000000000..773dff76c2 --- /dev/null +++ b/invokeai/backend/flux/controlnet/state_dict_utils.py @@ -0,0 +1,41 @@ +from typing import Any, Dict + + +def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool: + """Is the state dict for an XLabs ControlNet model? + + This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. + """ + # If all of the expected keys are present, then this is very likely an XLabs ControlNet model. + expected_keys = { + "controlnet_blocks.0.bias", + "controlnet_blocks.0.weight", + "input_hint_block.0.bias", + "input_hint_block.0.weight", + "pos_embed_input.bias", + "pos_embed_input.weight", + } + + if expected_keys.issubset(sd.keys()): + return True + return False + + +def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool: + """Is the state dict for an InstantX ControlNet model? + + This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. + """ + # If all of the expected keys are present, then this is very likely an InstantX ControlNet model. + expected_keys = { + "controlnet_blocks.0.bias", + "controlnet_blocks.0.weight", + "controlnet_single_blocks.0.bias", + "controlnet_single_blocks.0.weight", + "controlnet_x_embedder.bias", + "controlnet_x_embedder.weight", + } + + if expected_keys.issubset(sd.keys()): + return True + return False diff --git a/tests/backend/flux/controlnet/instantx_flux_controlnet_state_dict.py b/tests/backend/flux/controlnet/instantx_flux_controlnet_state_dict.py index bb913b9958..13633bc0b9 100644 --- a/tests/backend/flux/controlnet/instantx_flux_controlnet_state_dict.py +++ b/tests/backend/flux/controlnet/instantx_flux_controlnet_state_dict.py @@ -1,7 +1,7 @@ # State dict keys for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests. # These keys were extracted from: # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors -state_dict_keys = [ +instantx_state_dict_keys = [ "context_embedder.bias", "context_embedder.weight", "controlnet_blocks.0.bias", diff --git a/tests/backend/flux/controlnet/test_state_dict_utils.py b/tests/backend/flux/controlnet/test_state_dict_utils.py new file mode 100644 index 0000000000..2a7cf32e10 --- /dev/null +++ b/tests/backend/flux/controlnet/test_state_dict_utils.py @@ -0,0 +1,34 @@ +import pytest + +from invokeai.backend.flux.controlnet.state_dict_utils import ( + is_state_dict_instantx_controlnet, + is_state_dict_xlabs_controlnet, +) +from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_state_dict_keys +from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_state_dict_keys + + +@pytest.mark.parametrize( + ["sd_keys", "expected"], + [ + (xlabs_state_dict_keys, True), + (instantx_state_dict_keys, False), + (["foo"], False), + ], +) +def test_is_state_dict_xlabs_controlnet(sd_keys: list[str], expected: bool): + sd = {k: None for k in sd_keys} + assert is_state_dict_xlabs_controlnet(sd) == expected + + +@pytest.mark.parametrize( + ["sd_keys", "expected"], + [ + (instantx_state_dict_keys, True), + (xlabs_state_dict_keys, False), + (["foo"], False), + ], +) +def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool): + sd = {k: None for k in sd_keys} + assert is_state_dict_instantx_controlnet(sd) == expected diff --git a/tests/backend/flux/controlnet/xlabs_flux_controlnet_state_dict.py b/tests/backend/flux/controlnet/xlabs_flux_controlnet_state_dict.py index b4eb4ccdc8..c0fcca74dc 100644 --- a/tests/backend/flux/controlnet/xlabs_flux_controlnet_state_dict.py +++ b/tests/backend/flux/controlnet/xlabs_flux_controlnet_state_dict.py @@ -1,7 +1,7 @@ # State dict keys for an XLabs FLUX ControlNet model. Intended to be used for unit tests. # These keys were extracted from: # https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors -state_dict_keys = [ +xlabs_state_dict_keys = [ "controlnet_blocks.0.bias", "controlnet_blocks.0.weight", "controlnet_blocks.1.bias",