Rename LoRAModelRaw to ModelPatchRaw.

This commit is contained in:
Ryan Dick
2024-12-14 15:40:25 +00:00
parent b820862eab
commit 7fad4c9491
15 changed files with 52 additions and 50 deletions

View File

@@ -20,7 +20,7 @@ from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -66,10 +66,10 @@ class CompelInvocation(BaseInvocation):
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
@@ -162,11 +162,11 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
assert isinstance(lora_model, ModelPatchRaw)
yield (lora_model, lora.weight)
del lora_info
return

View File

@@ -39,7 +39,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
@@ -987,10 +987,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
return

View File

@@ -49,7 +49,7 @@ from invokeai.backend.flux.sampling_utils import (
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
@@ -715,7 +715,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
if self.control_lora:
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
@@ -723,7 +723,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
loras.append(self.control_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -19,7 +19,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -130,9 +130,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -18,7 +18,7 @@ from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
@@ -193,9 +193,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
) -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -22,7 +22,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
@@ -194,10 +194,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -17,7 +17,7 @@ from invokeai.backend.image_util.segment_anything.segment_anything_pipeline impo
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
@@ -43,7 +43,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
(
TextualInversionModelRaw,
IPAdapter,
LoRAModelRaw,
ModelPatchRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,

View File

@@ -8,7 +8,7 @@ from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlL
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
@@ -43,7 +43,7 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
@@ -81,4 +81,4 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
else:
raise ValueError(f"{layer_key} not expected")
return LoRAModelRaw(layers=layers)
return ModelPatchRaw(layers=layers)

View File

@@ -6,7 +6,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
@@ -30,7 +30,9 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
def lora_model_from_flux_diffusers_state_dict(
state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
@@ -215,7 +217,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
return LoRAModelRaw(layers=layers_with_prefix)
return ModelPatchRaw(layers=layers_with_prefix)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:

View File

@@ -9,7 +9,7 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
# Example keys:
@@ -39,7 +39,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
)
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
@@ -71,7 +71,7 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
return ModelPatchRaw(layers=layers)
T = TypeVar("T")

View File

@@ -4,17 +4,17 @@ import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, BaseLayerPatch] = {}
for layer_key, values in grouped_state_dict.items():
layers[layer_key] = any_lora_layer_from_state_dict(values)
return LoRAModelRaw(layers=layers)
return ModelPatchRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:

View File

@@ -7,7 +7,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel):
class ModelPatchRaw(RawModel):
def __init__(self, layers: Mapping[str, BaseLayerPatch]):
self.layers = layers

View File

@@ -5,7 +5,7 @@ import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
@@ -19,7 +19,7 @@ class ModelPatcher:
@contextmanager
def apply_model_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
@@ -57,7 +57,7 @@ class ModelPatcher:
def apply_model_patch(
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch: ModelPatchRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
@@ -148,7 +148,7 @@ class ModelPatcher:
@contextmanager
def apply_model_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[LoRAModelRaw, float]],
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
dtype: torch.dtype,
):
@@ -189,7 +189,7 @@ class ModelPatcher:
@staticmethod
def _apply_model_sidecar_patch(
model: torch.nn.Module,
patch: LoRAModelRaw,
patch: ModelPatchRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
from diffusers import UNet2DConditionModel
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
@@ -30,7 +30,7 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, LoRAModelRaw)
assert isinstance(lora_model, ModelPatchRaw)
ModelPatcher.apply_model_patch(
model=unet,
prefix="lora_unet_",

View File

@@ -2,7 +2,7 @@ import pytest
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.lora_model_raw import LoRAModelRaw
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.model_patcher import ModelPatcher
@@ -37,7 +37,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
# Initialize num_layers LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_layers):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
@@ -47,7 +47,7 @@ def test_apply_lora_patches(device: str, num_layers: int):
},
)
}
lora = LoRAModelRaw(lora_layers)
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
@@ -89,7 +89,7 @@ def test_apply_lora_patches_change_device():
},
)
}
lora = LoRAModelRaw(lora_layers)
lora = ModelPatchRaw(lora_layers)
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
@@ -128,7 +128,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
# Initialize num_layers LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_layers):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
@@ -138,7 +138,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int):
},
)
}
lora = LoRAModelRaw(lora_layers)
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
# Run inference before patching the model.
@@ -171,7 +171,7 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
# Initialize num_layers LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[LoRAModelRaw, float]] = []
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_layers):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
@@ -181,7 +181,7 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int):
},
)
}
lora = LoRAModelRaw(lora_layers)
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)