diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 071713316a..94cfa34803 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -11,6 +11,7 @@ from safetensors.torch import load_file
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager.config import AnyModelConfig
 from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.omi import convert_from_omi
 from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.taxonomy import (
@@ -73,6 +74,10 @@ class LoRALoader(ModelLoader):
         else:
             state_dict = torch.load(model_path, map_location="cpu")
 
+        if config.format == ModelFormat.OMI:
+            state_dict = convert_from_omi(state_dict)
+
+        # Apply state_dict key conversions, if necessary.
         if self._model_base == BaseModelType.StableDiffusionXL:
             state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
 
@@ -85,7 +90,7 @@ class LoRALoader(ModelLoader):
             # is a popular choice. For example, in the diffusers training scripts:
             # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
             model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
-        elif config.format == ModelFormat.LyCORIS:
+        elif config.format in [ModelFormat.LyCORIS, ModelFormat.OMI]:
             if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
                 model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
diff --git a/invokeai/backend/model_manager/omi.py b/invokeai/backend/model_manager/omi.py
new file mode 100644
index 0000000000..44abad4a99
--- /dev/null
+++ b/invokeai/backend/model_manager/omi.py
@@ -0,0 +1,43 @@
+import torch
+from invokeai.backend.util.logging import InvokeAILogger
+
+logger = InvokeAILogger.get_logger()
+
+
+def convert_from_omi(weights_sd):
+    # convert from OMI to default LoRA
+    # OMI format: {"prefix.module.name.lora_down.weight": weight, "prefix.module.name.lora_up.weight": weight, ...}
+    # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
+
+    new_weights_sd = {}
+    prefix = "lora_unet_"
+    lora_dims = {}
+    weight_dtype = None
+    for key, weight in weights_sd.items():
+        omi_prefix, key_body = key.split(".", 1)
+        if omi_prefix != "diffusion":
+            logger.warning(f"unexpected key: {key} in OMI format")  # T5, CLIP, etc.
+            continue
+
+        # only supports lora_down, lora_up and alpha
+        new_key = (
+            f"{prefix}{key_body}".replace(".", "_")
+            .replace("_lora_down_", ".lora_down.")
+            .replace("_lora_up_", ".lora_up.")
+            .replace("_alpha", ".alpha")
+        )
+        new_weights_sd[new_key] = weight
+
+        lora_name = new_key.split(".")[0]  # before first dot
+        if lora_name not in lora_dims and "lora_down" in new_key:
+            lora_dims[lora_name] = weight.shape[0]
+            if weight_dtype is None:
+                weight_dtype = weight.dtype  # use first weight dtype for lora_down
+
+    # add alpha with rank
+    for lora_name, dim in lora_dims.items():
+        alpha_key = f"{lora_name}.alpha"
+        if alpha_key not in new_weights_sd:
+            new_weights_sd[alpha_key] = torch.tensor(dim, dtype=weight_dtype)
+
+    return new_weights_sd
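
For reference, a minimal usage sketch of the new converter on a toy OMI-style state dict. The key names and tensor shapes below are illustrative only, not taken from a real checkpoint.

import torch

from invokeai.backend.model_manager.omi import convert_from_omi

# Toy OMI-format LoRA state dict: keys look like "diffusion.<module.path>.lora_down/lora_up.weight".
# Shapes are made up for illustration (rank-4 LoRA on a hypothetical FLUX attention qkv projection).
omi_sd = {
    "diffusion.double_blocks.0.img_attn.qkv.lora_down.weight": torch.zeros(4, 3072),
    "diffusion.double_blocks.0.img_attn.qkv.lora_up.weight": torch.zeros(9216, 4),
    "text_encoder.some.module.lora_down.weight": torch.zeros(4, 768),  # non-"diffusion" prefix: skipped with a warning
}

kohya_sd = convert_from_omi(omi_sd)

# Keys are remapped to the kohya-style "lora_unet_" prefix with underscores, e.g.
#   "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"
#   "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"
# and a synthetic alpha equal to the rank (here 4) is added for each layer:
#   "lora_unet_double_blocks_0_img_attn_qkv.alpha" -> tensor(4.)
# In this form the state dict should flow through the existing FLUX kohya handling in LoRALoader.
for key in sorted(kohya_sd):
    print(key, tuple(kohya_sd[key].shape))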