diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index 65993acf01..622d8ea60f 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -582,6 +582,8 @@ def invocation(
         fields: dict[str, tuple[Any, FieldInfo]] = {}
 
+        original_model_fields: dict[str, OriginalModelField] = {}
+
         for field_name, field_info in cls.model_fields.items():
             annotation = field_info.annotation
             assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
@@ -589,7 +591,7 @@ def invocation(
                 f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
             )
 
-            cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
+            original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
 
             validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
@@ -676,6 +678,7 @@ def invocation(
         docstring = cls.__doc__
         new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)  # type: ignore
         new_class.__doc__ = docstring
+        new_class._original_model_fields = original_model_fields
 
         InvocationRegistry.register_invocation(new_class)
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index ec63021c69..4dabac964b 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -24,7 +24,6 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
 INIT_FILE = Path("invokeai.yaml")
 DB_FILE = Path("invokeai.db")
 LEGACY_INIT_FILE = Path("invokeai.init")
-DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
 PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -93,7 +92,7 @@ class InvokeAIAppConfig(BaseSettings):
         vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
         lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
         pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
-        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
+        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -176,7 +175,7 @@ class InvokeAIAppConfig(BaseSettings):
     pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
 
     # DEVICE
-    device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
+    device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
     precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
 
     # GENERATION
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index e1ebb1856e..1fe9ebe1ef 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -296,7 +296,7 @@ class LoRAConfigBase(ABC, BaseModel):
         from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
 
         sd = mod.load_state_dict(mod.path)
-        value = flux_format_from_state_dict(sd)
+        value = flux_format_from_state_dict(sd, mod.metadata())
         mod.cache[key] = value
         return value
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 3c67986bee..e0dfd07dbe 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -21,6 +21,10 @@ from invokeai.backend.model_manager.taxonomy import (
     ModelType,
     SubModelType,
 )
+from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_aitoolkit_format,
+    lora_model_from_flux_aitoolkit_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
     is_state_dict_likely_flux_control,
     lora_model_from_flux_control_state_dict,
 )
@@ -99,6 +103,8 @@ class LoRALoader(ModelLoader):
                 model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_flux_control(state_dict=state_dict):
                 model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
+            elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
+                model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
             else:
                 raise ValueError("LoRA model is in unsupported FLUX format")
         else:
diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py
index 91aaf007d5..5285d53c25 100644
--- a/invokeai/backend/model_manager/taxonomy.py
+++ b/invokeai/backend/model_manager/taxonomy.py
@@ -138,6 +138,7 @@ class FluxLoRAFormat(str, Enum):
     Kohya = "flux.kohya"
     OneTrainer = "flux.onetrainer"
     Control = "flux.control"
+    AIToolkit = "flux.aitoolkit"
 
 
 AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py
new file mode 100644
index 0000000000..6ca06a0355
--- /dev/null
+++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py
@@ -0,0 +1,63 @@
+import json
+from dataclasses import dataclass, field
+from typing import Any
+
+import torch
+
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.util import InvokeAILogger
+
+
+def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] | None = None) -> bool:
+    if metadata:
+        try:
+            software = json.loads(metadata.get("software", "{}"))
+        except json.JSONDecodeError:
+            return False
+        return software.get("name") == "ai-toolkit"
+    # metadata got lost somewhere
+    return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
+
+
+@dataclass
+class GroupedStateDict:
+    transformer: dict[str, Any] = field(default_factory=dict)
+    # might also grow CLIP and T5 submodels
+
+
+def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict:
+    logger = InvokeAILogger.get_logger()
+    grouped = GroupedStateDict()
+    for key, value in state_dict.items():
+        submodel_name, param_name = key.split(".", 1)
+        match submodel_name:
+            case "diffusion_model":
+                grouped.transformer[param_name] = value
+            case _:
+                logger.warning(f"Unexpected submodel name: {submodel_name}")
+    return grouped
+
+
+def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    """Renames keys from the PEFT LoRA format to the InvokeAI format."""
+    renamed_state_dict = {}
+    for key, value in state_dict.items():
+        renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.")
+        renamed_state_dict[renamed_key] = value
+    return renamed_state_dict
+
+
+def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
+    state_dict = _rename_peft_lora_keys(state_dict)
+    by_layer = _group_by_layer(state_dict)
+    by_model = _group_state_by_submodel(by_layer)
+
+    layers: dict[str, BaseLayerPatch] = {}
+    for layer_key, layer_state_dict in by_model.transformer.items():
+        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+
+    return ModelPatchRaw(layers=layers)
diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py
index 4607310067..94f71e05ee 100644
--- a/invokeai/backend/patches/lora_conversions/formats.py
+++ b/invokeai/backend/patches/lora_conversions/formats.py
@@ -1,4 +1,7 @@
 from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
+from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_aitoolkit_format,
+)
 from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
 from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
@@ -11,7 +14,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
 )
 
 
-def flux_format_from_state_dict(state_dict):
+def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
     if is_state_dict_likely_in_flux_kohya_format(state_dict):
         return FluxLoRAFormat.Kohya
     elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
@@ -20,5 +23,7 @@ def flux_format_from_state_dict(state_dict):
         return FluxLoRAFormat.Diffusers
     elif is_state_dict_likely_flux_control(state_dict):
         return FluxLoRAFormat.Control
+    elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
+        return FluxLoRAFormat.AIToolkit
     else:
        return None
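Reviewer note: the detection and conversion pieces above compose roughly as follows. This is a minimal sketch rather than code from this PR — the file path is hypothetical, and the safetensors metadata plumbing stands in for what `mod.metadata()` supplies in `config.py`:

```python
from safetensors import safe_open
from safetensors.torch import load_file

from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
    lora_model_from_flux_aitoolkit_state_dict,
)
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict

lora_path = "my_flux_aitoolkit_lora.safetensors"  # hypothetical file

state_dict = load_file(lora_path)
with safe_open(lora_path, framework="pt") as f:
    # ai-toolkit stores a JSON "software" entry in the safetensors header,
    # e.g. {"software": '{"name": "ai-toolkit", ...}'}
    metadata = f.metadata() or {}

# Metadata is now threaded through format detection, since the
# "diffusion_model." key prefix alone is an ambiguous signal.
if flux_format_from_state_dict(state_dict, metadata) == FluxLoRAFormat.AIToolkit:
    patch = lora_model_from_flux_aitoolkit_state_dict(state_dict)
    print(f"Converted ai-toolkit LoRA into {type(patch).__name__}")
```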
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton.tsx
index 58dcfc0f1c..38ca25e41c 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton.tsx
@@ -19,6 +19,7 @@ export const CanvasToolbarSaveToGalleryButton = memo(() => {
       onClick={shift ? saveBboxToGallery : saveCanvasToGallery}
       icon={}
       aria-label={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
+      colorScheme="invokeBlue"
       tooltip={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
       isDisabled={isBusy}
     />
diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/publish.ts b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/publish.ts
index ba88099493..dadb12a72b 100644
--- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/publish.ts
+++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/publish.ts
@@ -122,11 +122,11 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
   'metadata_to_controlnets',
   'metadata_to_ip_adapters',
   'metadata_to_t2i_adapters',
-  'google_imagen3_generate',
-  'google_imagen3_edit',
-  'google_imagen4_generate',
-  'chatgpt_create_image',
-  'chatgpt_edit_image',
+  'google_imagen3_generate_image',
+  'google_imagen3_edit_image',
+  'google_imagen4_generate_image',
+  'chatgpt_4o_generate_image',
+  'chatgpt_4o_edit_image',
 ];
 
 export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index 9febef8269..cc5b91e2c7 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -12161,7 +12161,7 @@ export type components = {
      * vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
      * lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
      * pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
-     * device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
+     * device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
      * precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`
      * sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
      * attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -12436,11 +12436,10 @@ export type components = {
         pytorch_cuda_alloc_conf?: string | null;
         /**
          * Device
-         * @description Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
+         * @description Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
          * @default auto
-         * @enum {string}
          */
-        device?: "auto" | "cpu" | "cuda" | "cuda:1" | "mps";
+        device?: string;
         /**
          * Precision
          * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py
index 51551062c0..b1eaa2d447 100644
--- a/invokeai/version/invokeai_version.py
+++ b/invokeai/version/invokeai_version.py
@@ -1 +1 @@
-__version__ = "5.14.0"
+__version__ = "5.15.0"
diff --git a/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_aitoolkit_format.py b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_aitoolkit_format.py
new file mode 100644
index 0000000000..98b278df86
--- /dev/null
+++ b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_aitoolkit_format.py
@@ -0,0 +1,458 @@
+state_dict_keys = {
+    "diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.0.img_attn.proj.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": [9216, 16],
+    "diffusion_model.double_blocks.0.img_mlp.0.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.0.img_mlp.0.lora_B.weight": [12288, 16],
+    "diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight": [16, 12288],
+    "diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.0.txt_attn.proj.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.0.txt_attn.proj.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.0.txt_attn.qkv.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.0.txt_attn.qkv.lora_B.weight": [9216, 16],
+    "diffusion_model.double_blocks.0.txt_mlp.0.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.0.txt_mlp.0.lora_B.weight": [12288, 16],
+    "diffusion_model.double_blocks.0.txt_mlp.2.lora_A.weight": [16, 12288],
+    "diffusion_model.double_blocks.0.txt_mlp.2.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.1.img_attn.proj.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.1.img_attn.proj.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.1.img_attn.qkv.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.1.img_attn.qkv.lora_B.weight": [9216, 16],
+    "diffusion_model.double_blocks.1.img_mlp.0.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.1.img_mlp.0.lora_B.weight": [12288, 16],
+    "diffusion_model.double_blocks.1.img_mlp.2.lora_A.weight": [16, 12288],
+    "diffusion_model.double_blocks.1.img_mlp.2.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.1.txt_attn.proj.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.1.txt_attn.proj.lora_B.weight": [3072, 16],
+    "diffusion_model.double_blocks.1.txt_attn.qkv.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.1.txt_attn.qkv.lora_B.weight": [9216, 16],
+    "diffusion_model.double_blocks.1.txt_mlp.0.lora_A.weight": [16, 3072],
+    "diffusion_model.double_blocks.1.txt_mlp.0.lora_B.weight": [12288, 16],
+    "diffusion_model.double_blocks.1.txt_mlp.2.lora_A.weight": [16, 12288],
+    "diffusion_model.double_blocks.1.txt_mlp.2.lora_B.weight": [3072, 16],
+
"diffusion_model.double_blocks.10.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.10.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.10.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.10.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.10.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.10.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.10.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.11.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.11.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.11.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.11.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.11.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.11.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.12.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.12.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.12.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.12.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.txt_mlp.0.lora_B.weight": [12288, 16], + 
"diffusion_model.double_blocks.12.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.12.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.13.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.13.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.13.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.13.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.13.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.13.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.14.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.14.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.14.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.14.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.14.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.14.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.15.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.15.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.15.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.txt_attn.qkv.lora_B.weight": [9216, 16], + 
"diffusion_model.double_blocks.15.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.15.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.15.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.16.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.16.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.16.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.16.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.16.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.16.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.17.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.17.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.17.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.17.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.17.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.17.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.18.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.18.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.18.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.txt_attn.proj.lora_B.weight": [3072, 16], + 
"diffusion_model.double_blocks.18.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.18.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.18.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.18.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.2.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.2.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.2.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.2.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.2.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.2.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.3.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.3.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.3.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.3.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.3.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.3.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.4.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.4.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.4.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.txt_attn.proj.lora_A.weight": [16, 3072], + 
"diffusion_model.double_blocks.4.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.4.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.4.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.4.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.5.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.5.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.5.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.5.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.5.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.5.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.6.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.6.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.6.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.6.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.6.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.6.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.7.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.7.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.7.img_mlp.2.lora_B.weight": [3072, 16], + 
"diffusion_model.double_blocks.7.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.7.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.7.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.7.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.8.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.8.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.8.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.8.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.8.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.8.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.9.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.9.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.9.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.9.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.9.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.9.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.0.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.0.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.0.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.0.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.1.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.1.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.1.linear2.lora_A.weight": [16, 15360], + 
"diffusion_model.single_blocks.1.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.10.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.10.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.10.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.10.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.11.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.11.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.11.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.11.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.12.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.12.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.12.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.12.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.13.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.13.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.13.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.13.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.14.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.14.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.14.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.14.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.15.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.15.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.15.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.15.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.16.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.16.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.16.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.16.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.17.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.17.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.17.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.17.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.18.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.18.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.18.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.18.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.19.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.19.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.19.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.19.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.2.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.2.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.2.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.2.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.20.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.20.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.20.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.20.linear2.lora_B.weight": [3072, 16], + 
"diffusion_model.single_blocks.21.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.21.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.21.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.21.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.22.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.22.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.22.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.22.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.23.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.23.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.23.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.23.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.24.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.24.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.24.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.24.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.25.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.25.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.25.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.25.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.26.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.26.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.26.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.26.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.27.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.27.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.27.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.27.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.28.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.28.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.28.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.28.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.29.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.29.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.29.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.29.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.3.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.3.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.3.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.3.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.30.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.30.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.30.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.30.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.31.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.31.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.31.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.31.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.32.linear1.lora_A.weight": [16, 3072], + 
"diffusion_model.single_blocks.32.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.32.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.32.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.33.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.33.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.33.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.33.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.34.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.34.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.34.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.34.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.35.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.35.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.35.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.35.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.36.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.36.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.36.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.36.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.37.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.37.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.37.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.37.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.4.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.4.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.4.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.4.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.5.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.5.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.5.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.5.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.6.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.6.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.6.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.6.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.7.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.7.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.7.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.7.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.8.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.8.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.8.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.8.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.9.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.9.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.9.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.9.linear2.lora_B.weight": [3072, 16], +} diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py 
b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py
new file mode 100644
index 0000000000..ed3e05a9b2
--- /dev/null
+++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py
@@ -0,0 +1,59 @@
+import accelerate
+import pytest
+
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import params
+from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
+    _group_state_by_submodel,
+    is_state_dict_likely_in_flux_aitoolkit_format,
+    lora_model_from_flux_aitoolkit_state_dict,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
+    state_dict_keys as flux_onetrainer_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import (
+    state_dict_keys as flux_aitoolkit_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
+    state_dict_keys as flux_diffusers_state_dict_keys,
+)
+from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
+
+
+def test_is_state_dict_likely_in_flux_aitoolkit_format():
+    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
+    assert is_state_dict_likely_in_flux_aitoolkit_format(state_dict)
+
+
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
+def test_is_state_dict_likely_in_flux_aitoolkit_format_false(sd_keys: dict[str, list[int]]):
+    state_dict = keys_to_mock_state_dict(sd_keys)
+    assert not is_state_dict_likely_in_flux_aitoolkit_format(state_dict)
+
+
+def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
+    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
+    converted_state_dict = _group_state_by_submodel(state_dict).transformer
+
+    # Extract the prefixes from the converted state dict (without the lora suffixes)
+    converted_key_prefixes: list[str] = []
+    for k in converted_state_dict.keys():
+        k = k.replace(".lora_A.weight", "")
+        k = k.replace(".lora_B.weight", "")
+        converted_key_prefixes.append(k)
+
+    # Initialize a FLUX model on the meta device.
+    with accelerate.init_empty_weights():
+        model = Flux(params["flux-schnell"])
+    model_keys = set(model.state_dict().keys())
+
+    for converted_key_prefix in converted_key_prefixes:
+        assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), (
+            f"'{converted_key_prefix}' did not match any model keys."
+        )
+
+
+def test_lora_model_from_flux_aitoolkit_state_dict():
+    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
+
+    assert lora_model_from_flux_aitoolkit_state_dict(state_dict)
diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py
index 8e810e4367..b65137c08d 100644
--- a/tests/backend/util/test_devices.py
+++ b/tests/backend/util/test_devices.py
@@ -10,7 +10,7 @@ import torch
 from invokeai.app.services.config import get_config
 from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
 
-devices = ["cpu", "cuda:0", "cuda:1", "mps"]
+devices = ["cpu", "cuda:0", "cuda:1", "cuda:2", "mps"]
 device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
 device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
 device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]
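Reviewer note: a quick way to sanity-check the loosened `device` setting is to validate candidate strings against the same regex the new `Field(pattern=...)` enforces in `config_default.py`. A minimal sketch — the helper name is illustrative, only the pattern itself comes from this diff:

```python
import re

# Same pattern the new `device` Field enforces in config_default.py.
DEVICE_PATTERN = re.compile(r"^(auto|cpu|mps|cuda(:\d+)?)$")


def is_valid_device_setting(value: str) -> bool:
    """Return True if `value` would pass the config's pattern validation."""
    return DEVICE_PATTERN.fullmatch(value) is not None


# The old Literal only allowed `cuda:1` as an indexed device; the pattern now
# accepts any index, which the added `cuda:2` case in test_devices.py exercises.
assert is_valid_device_setting("cuda:2")
assert is_valid_device_setting("auto")
assert not is_valid_device_setting("cuda:")  # an index is required after the colon
assert not is_valid_device_setting("xpu")    # not an allowed device type
```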