Move the responsibilities of 1) state_dict loading from file, and 2) SDXL lora key conversions, out of LoRAModelRaw and into LoRALoader.

This commit is contained in:
Ryan Dick
2024-09-04 15:18:43 +00:00
committed by Kent Keirsey
parent 8518ae9ccb
commit 04b37e64ea
2 changed files with 27 additions and 41 deletions

View File

@@ -1,14 +1,9 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional
import torch
from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
@@ -16,7 +11,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
@@ -36,69 +30,51 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
return model_size
@classmethod
def from_checkpoint(
def from_state_dict(
cls,
file_path: Union[str, Path],
state_dict: Dict[str, torch.Tensor],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = cls._group_state(state_dict)
del state_dict # Delete state_dict so that layers can be gc'd as they are processed.
model = cls(layers={})
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
raise ValueError(f"Unsupported lora format: {layer_key} - {list(values.keys())}")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
# Reduce memory consumption by removing references to layer values that have already been handled.
grouped_state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
layers[layer_key] = layer
return model
return cls(layers=layers)
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:

View File

@@ -5,7 +5,11 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
@@ -45,12 +49,18 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# TODO(ryand): Add conversions for other base models and raise if an unsupported base model is used.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
return LoRAModelRaw.from_state_dict(state_dict=state_dict, dtype=self._torch_dtype)
# override
def _get_model_path(self, config: AnyModelConfig) -> Path: