From 45d09f8f51fb3107b68cb802b87742a74443d48d Mon Sep 17 00:00:00 2001 From: Billy Date: Thu, 19 Jun 2025 09:40:49 +1000 Subject: [PATCH] Use OMI conversion utils --- invokeai/backend/model_manager/config.py | 3 - .../model_manager/load/model_loaders/lora.py | 5 +- invokeai/backend/model_manager/omi.py | 57 +++++-------------- pyproject.toml | 1 + 4 files changed, 18 insertions(+), 48 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 83f7ac61e9..b995815d1f 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -357,15 +357,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): if "stable-diffusion-v1" in base_str: base = BaseModelType.StableDiffusion1 - elif "stable-diffusion-v2" in base_str: - base = BaseModelType.StableDiffusion2 elif "stable-diffusion-v3" in base_str: base = BaseModelType.StableDiffusion3 elif base_str == "stable-diffusion-xl-v1-base": base = BaseModelType.StableDiffusionXL elif "flux" in base_str: base = BaseModelType.Flux - else: raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}") diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 2609c01f15..dee1717709 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -13,7 +13,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.omi import convert_from_omi +from invokeai.backend.model_manager.omi import convert_to_omi from invokeai.backend.model_manager.taxonomy import ( AnyModel, BaseModelType, @@ -41,7 +41,6 @@ from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.LoRA, format=ModelFormat.OMI) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.LoRA, format=ModelFormat.OMI) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.LoRA, format=ModelFormat.OMI) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI) @@ -80,7 +79,7 @@ class LoRALoader(ModelLoader): state_dict = torch.load(model_path, map_location="cpu") if config.format == ModelFormat.OMI: - state_dict = convert_from_omi(state_dict) + state_dict = convert_to_omi(state_dict. config.base) # type: ignore # Apply state_dict key conversions, if necessary. if self._model_base == BaseModelType.StableDiffusionXL: diff --git a/invokeai/backend/model_manager/omi.py b/invokeai/backend/model_manager/omi.py index 8698937ace..477294c6ec 100644 --- a/invokeai/backend/model_manager/omi.py +++ b/invokeai/backend/model_manager/omi.py @@ -1,44 +1,17 @@ -import torch - -from invokeai.backend.util.logging import InvokeAILogger - -logger = InvokeAILogger.get_logger() +from invokeai.backend.model_manager.model_on_disk import StateDict +from invokeai.backend.model_manager.taxonomy import BaseModelType +from omi_model_standards.convert.lora.convert_sdxl_lora import convert_sdxl_lora_key_sets +from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets +from omi_model_standards.convert.lora.convert_sd_lora import convert_sd_lora_key_sets +from omi_model_standards.convert.lora.convert_sd3_lora import convert_sd3_lora_key_sets +import omi_model_standards.convert.lora.convert_lora_util as lora_util -def convert_from_omi(weights_sd): - # convert from OMI to default LoRA - # OMI format: {"prefix.module.name.lora_down.weight": weight, "prefix.module.name.lora_up.weight": weight, ...} - # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...} - - new_weights_sd = {} - prefix = "lora_unet_" - lora_dims = {} - weight_dtype = None - for key, weight in weights_sd.items(): - omi_prefix, key_body = key.split(".", 1) - if omi_prefix != "diffusion": - logger.warning(f"unexpected key: {key} in OMI format") # T5, CLIP, etc. - continue - - # only supports lora_down, lora_up and alpha - new_key = ( - f"{prefix}{key_body}".replace(".", "_") - .replace("_lora_down_", ".lora_down.") - .replace("_lora_up_", ".lora_up.") - .replace("_alpha", ".alpha") - ) - new_weights_sd[new_key] = weight - - lora_name = new_key.split(".")[0] # before first dot - if lora_name not in lora_dims and "lora_down" in new_key: - lora_dims[lora_name] = weight.shape[0] - if weight_dtype is None: - weight_dtype = weight.dtype # use first weight dtype for lora_down - - # add alpha with rank - for lora_name, dim in lora_dims.items(): - alpha_key = f"{lora_name}.alpha" - if alpha_key not in new_weights_sd: - new_weights_sd[alpha_key] = torch.tensor(dim, dtype=weight_dtype) - - return new_weights_sd +def convert_to_omi(weights_sd: StateDict, base: BaseModelType): + keyset = { + BaseModelType.Flux: convert_flux_lora_key_sets(), + BaseModelType.StableDiffusionXL: convert_sdxl_lora_key_sets(), + BaseModelType.StableDiffusion1: convert_sd_lora_key_sets(), + BaseModelType.StableDiffusion3: convert_sd3_lora_key_sets(), + }[base] + return lora_util.convert_to_omi(weights_sd, keyset) diff --git a/pyproject.toml b/pyproject.toml index 69f6a4150f..045ea40639 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dependencies = [ "python-multipart", "requests", "semver~=3.0.1", + "omi-model-standards @ git+https://github.com/Open-Model-Initiative/OMI-Model-Standards.git@4ad235ceba6b42a97942834b7664379e4ec2d93c" ] [project.optional-dependencies]