mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(mm): diffusers loras
w
This commit is contained in:
@@ -29,7 +29,13 @@ from invokeai.app.services.model_records import (
|
||||
)
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SD_1_2_XL_XLRefiner_CheckpointConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
Main_SD1_Checkpoint_Config,
|
||||
Main_SD2_Checkpoint_Config,
|
||||
Main_SDXL_Checkpoint_Config,
|
||||
Main_SDXLRefiner_Checkpoint_Config,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
@@ -738,7 +744,15 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
|
||||
if isinstance(model_config, SD_1_2_XL_XLRefiner_CheckpointConfig):
|
||||
if isinstance(
|
||||
model_config,
|
||||
(
|
||||
Main_SD1_Checkpoint_Config,
|
||||
Main_SD2_Checkpoint_Config,
|
||||
Main_SDXL_Checkpoint_Config,
|
||||
Main_SDXLRefiner_Checkpoint_Config,
|
||||
),
|
||||
):
|
||||
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
|
||||
logger.error(msg)
|
||||
raise HTTPException(400, msg)
|
||||
|
||||
@@ -17,8 +17,7 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
IPAdapter_InvokeAI_Config_Base,
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapter_FLUX_Checkpoint_Config,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
@@ -68,7 +67,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
|
||||
assert isinstance(ip_adapter_info, IPAdapter_FLUX_Checkpoint_Config)
|
||||
|
||||
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
|
||||
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||
|
||||
@@ -13,8 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
IPAdapter_Checkpoint_Config_Base,
|
||||
IPAdapter_InvokeAI_Config_Base,
|
||||
IPAdapterCheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.starter_models import (
|
||||
StarterModel,
|
||||
@@ -123,7 +123,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
|
||||
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))
|
||||
|
||||
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
|
||||
@@ -748,44 +748,78 @@ class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapterConfigBase, ModelConfigBase)
|
||||
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
|
||||
|
||||
|
||||
# LoRADiffusers_SupportedBases: TypeAlias = Literal[
|
||||
# BaseModelType.StableDiffusion1,
|
||||
# BaseModelType.StableDiffusion2,
|
||||
# BaseModelType.StableDiffusionXL,
|
||||
# BaseModelType.Flux,
|
||||
# ]
|
||||
class LoRA_Diffusers_Config_Base(LoRAConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
|
||||
# the weights format. FLUX Diffusers LoRAs are single files.
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
cls._validate_base(mod)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@classmethod
|
||||
def _validate_base(cls, mod: ModelOnDisk) -> None:
|
||||
"""Raise `NotAMatch` if the model base does not match this config class."""
|
||||
expected_base = cls.model_fields["base"].default.value
|
||||
recognized_base = cls._get_base_or_raise(mod)
|
||||
if expected_base is not recognized_base:
|
||||
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
if _get_flux_lora_format(mod):
|
||||
return BaseModelType.Flux
|
||||
|
||||
# If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
|
||||
path_to_weight_file = cls._get_weight_file_or_raise(mod)
|
||||
state_dict = mod.load_state_dict(path_to_weight_file)
|
||||
token_vector_length = lora_token_vector_length(state_dict)
|
||||
|
||||
match token_vector_length:
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
case 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
|
||||
|
||||
@classmethod
|
||||
def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
|
||||
suffixes = ["bin", "safetensors"]
|
||||
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
for wf in weight_files:
|
||||
if wf.exists():
|
||||
return wf
|
||||
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
|
||||
|
||||
|
||||
# class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
# """Model config for LoRA/Diffusers models."""
|
||||
class LoRA_SD1_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
||||
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
|
||||
|
||||
# # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
|
||||
# # the weights format. FLUX Diffusers LoRAs are single files.
|
||||
|
||||
# base: LoRADiffusers_SupportedBases = Field()
|
||||
# format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
class LoRA_SD2_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
||||
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
|
||||
|
||||
# @classmethod
|
||||
# def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
# _validate_is_dir(cls, mod)
|
||||
|
||||
# _validate_override_fields(cls, fields)
|
||||
class LoRA_SDXL_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
||||
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
||||
|
||||
# cls._validate_looks_like_diffusers_lora(mod)
|
||||
|
||||
# return cls(**fields)
|
||||
|
||||
# @classmethod
|
||||
# def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
|
||||
# suffixes = ["bin", "safetensors"]
|
||||
# weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
# has_lora_weight_file = any(wf.exists() for wf in weight_files)
|
||||
# if not has_lora_weight_file:
|
||||
# raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
|
||||
|
||||
# flux_lora_format = _get_flux_lora_format(mod)
|
||||
# if flux_lora_format is not FluxLoRAFormat.Diffusers:
|
||||
# raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
|
||||
class LoRA_FLUX_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
|
||||
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
||||
|
||||
|
||||
class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
|
||||
@@ -2332,8 +2366,11 @@ AnyModelConfig = Annotated[
|
||||
# LoRA - OMI format
|
||||
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
|
||||
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
|
||||
# LoRA - diffusers format (TODO)
|
||||
# Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
# LoRA - diffusers format
|
||||
Annotated[LoRA_SD1_Diffusers_Config, LoRA_SD1_Diffusers_Config.get_tag()],
|
||||
Annotated[LoRA_SD2_Diffusers_Config, LoRA_SD2_Diffusers_Config.get_tag()],
|
||||
Annotated[LoRA_SDXL_Diffusers_Config, LoRA_SDXL_Diffusers_Config.get_tag()],
|
||||
Annotated[LoRA_FLUX_Diffusers_Config, LoRA_FLUX_Diffusers_Config.get_tag()],
|
||||
# ControlLoRA - diffusers format
|
||||
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
|
||||
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
|
||||
|
||||
@@ -13,7 +13,7 @@ from invokeai.backend.util import InvokeAILogger
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_aitoolkit_format(
|
||||
state_dict: dict[str, Any],
|
||||
state_dict: dict[str | int, Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
if metadata:
|
||||
@@ -23,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(
|
||||
return False
|
||||
return software.get("name") == "ai-toolkit"
|
||||
# metadata got lost somewhere
|
||||
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
|
||||
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str))
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -25,7 +25,9 @@ def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
|
||||
all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
|
||||
all_keys_match = all(
|
||||
re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str)
|
||||
)
|
||||
|
||||
# Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
|
||||
lora_a_weight = state_dict.get("img_in.lora_A.weight", None)
|
||||
|
||||
@@ -9,14 +9,16 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_L
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
|
||||
def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
|
||||
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
|
||||
all_keys_in_peft_format = all(
|
||||
k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str)
|
||||
)
|
||||
|
||||
# Check if keys use transformer prefix
|
||||
transformer_prefix_keys = [
|
||||
|
||||
@@ -44,7 +44,7 @@ FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self
|
||||
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
|
||||
def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
@@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
|
||||
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
|
||||
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
|
||||
for k in state_dict.keys()
|
||||
if isinstance(k, str)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
|
||||
)
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
|
||||
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
@@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
|
||||
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
|
||||
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
|
||||
for k in state_dict.keys()
|
||||
if isinstance(k, str)
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user