diff --git a/invokeai/backend/lora/conversions/flux_control_lora_utils.py b/invokeai/backend/lora/conversions/flux_control_lora_utils.py index b52ad47a0b..1f3222a1be 100644 --- a/invokeai/backend/lora/conversions/flux_control_lora_utils.py +++ b/invokeai/backend/lora/conversions/flux_control_lora_utils.py @@ -14,7 +14,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw # guidance_in.in_layer.lora_B.bias # single_blocks.0.linear1.lora_A.weight # double_blocks.0.img_attn.norm.key_norm.scale -FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)" +FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)" def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool: @@ -23,7 +23,23 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool: This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ - return all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys()) + + all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys()) + + # Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs. + lora_a_weight = state_dict.get("img_in.lora_A.weight", None) + lora_b_bias = state_dict.get("img_in.lora_B.bias", None) + lora_b_weight = state_dict.get("img_in.lora_B.weight", None) + + return ( + all_keys_match + and lora_a_weight is not None + and lora_b_bias is not None + and lora_b_weight is not None + and lora_a_weight.shape[1] == 128 + and lora_b_weight.shape[0] == 3072 + and lora_b_bias.shape[0] == 3072 + ) def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw: diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 268c94f410..b03367147b 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -259,17 +259,8 @@ class ModelProbe(object): ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) ckpt = ckpt.get("state_dict", ckpt) - if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt: - tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"] - if ( - tensor_a is not None - and isinstance(tensor_a, torch.Tensor) - and tensor_a.shape[1] == 128 - and tensor_b is not None - and isinstance(tensor_b, torch.Tensor) - and tensor_b.shape[0] == 3072 - ): - return ModelType.ControlLoRa + if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt): + return ModelType.ControlLoRa for key in [str(k) for k in ckpt.keys()]: if key.startswith(