refactor(mm): diffusers loras

w
This commit is contained in:
psychedelicious
2025-10-01 18:00:50 +10:00
parent 629db4acfe
commit 315ddefbf1
9 changed files with 102 additions and 46 deletions

View File

@@ -29,7 +29,13 @@ from invokeai.app.services.model_records import (
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import AnyModelConfig, SD_1_2_XL_XLRefiner_CheckpointConfig
from invokeai.backend.model_manager.config import (
AnyModelConfig,
Main_SD1_Checkpoint_Config,
Main_SD2_Checkpoint_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXLRefiner_Checkpoint_Config,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -738,7 +744,15 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
if isinstance(model_config, SD_1_2_XL_XLRefiner_CheckpointConfig):
if isinstance(
model_config,
(
Main_SD1_Checkpoint_Config,
Main_SD2_Checkpoint_Config,
Main_SDXL_Checkpoint_Config,
Main_SDXLRefiner_Checkpoint_Config,
),
):
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
logger.error(msg)
raise HTTPException(400, msg)

View File

@@ -17,8 +17,7 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapter_InvokeAI_Config_Base,
IPAdapterCheckpointConfig,
IPAdapter_FLUX_Checkpoint_Config,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@@ -68,7 +67,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, IPAdapter_FLUX_Checkpoint_Config)
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

View File

@@ -13,8 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
IPAdapter_Checkpoint_Config_Base,
IPAdapter_InvokeAI_Config_Base,
IPAdapterCheckpointConfig,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
@@ -123,7 +123,7 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id

View File

@@ -748,44 +748,78 @@ class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase)
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
# LoRADiffusers_SupportedBases: TypeAlias = Literal[
# BaseModelType.StableDiffusion1,
# BaseModelType.StableDiffusion2,
# BaseModelType.StableDiffusionXL,
# BaseModelType.Flux,
# ]
class LoRA_Diffusers_Config_Base(LoRAConfigBase):
"""Model config for LoRA/Diffusers models."""
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
# the weights format. FLUX Diffusers LoRAs are single files.
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_dir(cls, mod)
_validate_override_fields(cls, fields)
cls._validate_base(mod)
return cls(**fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default.value
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
if _get_flux_lora_format(mod):
return BaseModelType.Flux
# If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
path_to_weight_file = cls._get_weight_file_or_raise(mod)
state_dict = mod.load_state_dict(path_to_weight_file)
token_vector_length = lora_token_vector_length(state_dict)
match token_vector_length:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
@classmethod
def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
for wf in weight_files:
if wf.exists():
return wf
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
# class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
# """Model config for LoRA/Diffusers models."""
class LoRA_SD1_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
    """Model config for Stable Diffusion 1.x LoRAs in Diffusers format."""

    # Pins the base discriminator; LoRA_Diffusers_Config_Base validates it against the model on disk.
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
# # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
# # the weights format. FLUX Diffusers LoRAs are single files.
# base: LoRADiffusers_SupportedBases = Field()
# format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
class LoRA_SD2_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
    """Model config for Stable Diffusion 2.x LoRAs in Diffusers format."""

    # Pins the base discriminator; LoRA_Diffusers_Config_Base validates it against the model on disk.
    base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
# @classmethod
# def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
# _validate_is_dir(cls, mod)
# _validate_override_fields(cls, fields)
class LoRA_SDXL_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
    """Model config for Stable Diffusion XL LoRAs in Diffusers format."""

    # Pins the base discriminator; LoRA_Diffusers_Config_Base validates it against the model on disk.
    base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
# cls._validate_looks_like_diffusers_lora(mod)
# return cls(**fields)
# @classmethod
# def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
# suffixes = ["bin", "safetensors"]
# weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
# has_lora_weight_file = any(wf.exists() for wf in weight_files)
# if not has_lora_weight_file:
# raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
# flux_lora_format = _get_flux_lora_format(mod)
# if flux_lora_format is not FluxLoRAFormat.Diffusers:
# raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
class LoRA_FLUX_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
    """Model config for FLUX LoRAs in Diffusers format."""

    # Pins the base discriminator; LoRA_Diffusers_Config_Base validates it against the model on disk.
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
@@ -2332,8 +2366,11 @@ AnyModelConfig = Annotated[
# LoRA - OMI format
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
# LoRA - diffusers format (TODO)
# Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
# LoRA - diffusers format
Annotated[LoRA_SD1_Diffusers_Config, LoRA_SD1_Diffusers_Config.get_tag()],
Annotated[LoRA_SD2_Diffusers_Config, LoRA_SD2_Diffusers_Config.get_tag()],
Annotated[LoRA_SDXL_Diffusers_Config, LoRA_SDXL_Diffusers_Config.get_tag()],
Annotated[LoRA_FLUX_Diffusers_Config, LoRA_FLUX_Diffusers_Config.get_tag()],
# ControlLoRA - diffusers format
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],

View File

@@ -13,7 +13,7 @@ from invokeai.backend.util import InvokeAILogger
def is_state_dict_likely_in_flux_aitoolkit_format(
state_dict: dict[str, Any],
state_dict: dict[str | int, Any],
metadata: dict[str, Any] | None = None,
) -> bool:
if metadata:
@@ -23,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(
return False
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str))
@dataclass

View File

@@ -25,7 +25,9 @@ def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
all_keys_match = all(
re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str)
)
# Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
lora_a_weight = state_dict.get("img_in.lora_A.weight", None)

View File

@@ -9,14 +9,16 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_L
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
all_keys_in_peft_format = all(
k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str)
)
# Check if keys use transformer prefix
transformer_prefix_keys = [

View File

@@ -44,7 +44,7 @@ FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
if isinstance(k, str)
)

View File

@@ -40,7 +40,7 @@ FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
)
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
if isinstance(k, str)
)