Move the responsibilities of 1) loading a state_dict from file, and 2) SDXL LoRA key conversion, out of LoRAModelRaw and into LoRALoader.

This commit is contained in:
Ryan Dick
2024-09-04 15:18:43 +00:00
parent 6dc4baa925
commit de3edf47fb
2 changed files with 27 additions and 41 deletions

View File

@@ -5,6 +5,9 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
@@ -17,6 +20,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.peft.lora import LoRAModelRaw
@@ -45,12 +49,18 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# TODO(ryand): Add conversions for other base models and raise if an unsupported base model is used.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
return LoRAModelRaw.from_state_dict(state_dict=state_dict, dtype=self._torch_dtype)
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:

View File

@@ -1,15 +1,9 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional
import torch
from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.full_layer import FullLayer
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
@@ -36,69 +30,51 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
return model_size
@classmethod
def from_checkpoint(
def from_state_dict(
cls,
file_path: Union[str, Path],
state_dict: Dict[str, torch.Tensor],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = cls._group_state(state_dict)
del state_dict # Delete state_dict so that layers can be gc'd as they are processed.
model = cls(layers={})
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
raise ValueError(f"Unsupported lora format: {layer_key} - {list(values.keys())}")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
# Reduce memory consumption by removing references to layer values that have already been handled.
grouped_state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
layers[layer_key] = layer
return model
return cls(layers=layers)
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: