mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Move two responsibilities out of LoRAModelRaw and into LoRALoader: 1) loading the state_dict from file, and 2) SDXL LoRA key conversions.
This commit is contained in:
@@ -1,14 +1,9 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.full_layer import FullLayer
|
||||
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
|
||||
@@ -16,7 +11,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@@ -36,69 +30,51 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
def from_state_dict(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = cls._group_state(state_dict)
|
||||
del state_dict # Delete state_dict so that layers can be gc'd as they are processed.
|
||||
|
||||
model = cls(layers={})
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, values in grouped_state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_a" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
# norms
|
||||
elif "w_norm" in values:
|
||||
layer = NormLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported lora format: {layer_key} - {list(values.keys())}")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
# Reduce memory consumption by removing references to layer values that have already been handled.
|
||||
grouped_state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
return cls(layers=layers)
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
|
||||
@@ -5,7 +5,11 @@ from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@@ -45,12 +49,18 @@ class LoRALoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
return model
|
||||
|
||||
# Load the state dict from the model file.
|
||||
if model_path.suffix == ".safetensors":
|
||||
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# TODO(ryand): Add conversions for other base models and raise if an unsupported base model is used.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
return LoRAModelRaw.from_state_dict(state_dict=state_dict, dtype=self._torch_dtype)
|
||||
|
||||
# override
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
|
||||
Reference in New Issue
Block a user