From 85c4304efd6a7651cf2a998402f1c2eb8f6f5ecd Mon Sep 17 00:00:00 2001 From: Billy Date: Tue, 17 Jun 2025 13:34:03 +1000 Subject: [PATCH] Add OMI LoRA config --- invokeai/backend/model_manager/config.py | 40 ++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index f4c057cb5b..656091ee2e 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -31,6 +31,7 @@ from pathlib import Path from typing import ClassVar, Literal, Optional, TypeAlias, Union from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter +from pydantic.main import Model from typing_extensions import Annotated, Any, Dict from invokeai.app.util.misc import uuid_string @@ -334,6 +335,44 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b +class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): + format: Literal[ModelFormat.OMI] = ModelFormat.OMI + + @classmethod + def matches(cls, mod: ModelOnDisk) -> bool: + if mod.path.is_dir(): + return False + + metadata = mod.metadata() + return ( + metadata.get("modelspec.sai_model_spec") and + metadata.get("ot_branch") == "omi_format" and + metadata["modelspec.architecture"].split("/")[1].lower() == "lora" + ) + + @classmethod + def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: + metadata = mod.metadata() + base_str, _ = metadata["modelspec.architecture"].split("/") + base_str = base_str.lower() + + if "stable-diffusion-v1" in base_str: + base = BaseModelType.StableDiffusion1 + elif "stable-diffusion-v2" in base_str: + base = BaseModelType.StableDiffusion2 + elif "stable-diffusion-v3" in base_str: + base = BaseModelType.StableDiffusion3 + elif base_str == "stable-diffusion-xl-v1-base": + base = BaseModelType.StableDiffusionXL + elif "flux" in base_str: + base = BaseModelType.Flux + + else: + raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}") + + return { "base": base } + + class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -668,6 +707,7 @@ AnyModelConfig = Annotated[ Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()], Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], + Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()], Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()], Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()], Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],