mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Finish consolidating LoRA sidecar wrapper implementations.
This commit is contained in:
@@ -325,11 +325,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
|
||||
# than directly patching the weights, but is agnostic to the quantization format.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_sidecar_patches(
|
||||
LoRAPatcher.apply_lora_wrapper_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -3,9 +3,6 @@ from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_layer_wrappers import (
|
||||
LoRAConv1dWrapper,
|
||||
LoRAConv2dWrapper,
|
||||
@@ -13,11 +10,6 @@ from invokeai.backend.lora.lora_layer_wrappers import (
|
||||
LoRASidecarWrapper,
|
||||
)
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
|
||||
ConcatenatedLoRALinearSidecarLayer,
|
||||
)
|
||||
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
@@ -154,7 +146,7 @@ class LoRAPatcher:
|
||||
):
|
||||
"""Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
|
||||
runtime overhead compared to normal LoRA patching, but they enable:
|
||||
- LoRA layers to be applied to quantization format that are quatnized at the tensor level.
|
||||
- LoRA layers to be applied to quantized models
|
||||
- LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
|
||||
avoid doubling the memory requirements).
|
||||
|
||||
@@ -231,102 +223,6 @@ class LoRAPatcher:
|
||||
# Add the LoRA wrapper layer to the LoRASidecarWrapper.
|
||||
lora_wrapper_layer.add_lora_layer(layer, patch_weight)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_sidecar_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
|
||||
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
|
||||
quantization format.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to patch.
|
||||
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
|
||||
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
|
||||
all at once.
|
||||
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
|
||||
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
|
||||
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
|
||||
different from their compute dtype.
|
||||
"""
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher._apply_lora_sidecar_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
# Restore original modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for module_key, orig_module in original_modules.items():
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
parent_module = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
|
||||
|
||||
@staticmethod
|
||||
def _apply_lora_sidecar_patch(
|
||||
model: torch.nn.Module,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
prefix: str,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA sidecar patch to a model."""
|
||||
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LoRAPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# Initialize the LoRA sidecar layer.
|
||||
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
|
||||
|
||||
# Replace the original module with a LoRASidecarModule if it has not already been done.
|
||||
if module_key in original_modules:
|
||||
# The module has already been patched with a LoRASidecarModule. Append to it.
|
||||
assert isinstance(module, LoRASidecarModule)
|
||||
lora_sidecar_module = module
|
||||
else:
|
||||
# The module has not yet been patched with a LoRASidecarModule. Create one.
|
||||
lora_sidecar_module = LoRASidecarModule(module, [])
|
||||
original_modules[module_key] = module
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
module_parent = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
|
||||
|
||||
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
|
||||
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
|
||||
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
|
||||
|
||||
# Add the LoRA sidecar layer to the LoRASidecarModule.
|
||||
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
|
||||
|
||||
@staticmethod
|
||||
def _split_parent_key(module_key: str) -> tuple[str, str]:
|
||||
"""Split a module key into its parent key and module name.
|
||||
@@ -356,21 +252,6 @@ class LoRAPatcher:
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
|
||||
|
||||
@staticmethod
|
||||
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
|
||||
# TODO(ryand): Add support for more original layer types and LoRA layer types.
|
||||
if isinstance(orig_layer, torch.nn.Linear) or (
|
||||
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
|
||||
):
|
||||
if isinstance(lora_layer, LoRALayer):
|
||||
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
|
||||
elif isinstance(lora_layer, ConcatenatedLoRALayer):
|
||||
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
|
||||
else:
|
||||
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
|
||||
|
||||
@staticmethod
|
||||
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
|
||||
try:
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
|
||||
|
||||
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
concatenated_lora_layer: ConcatenatedLoRALayer,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._concatenated_lora_layer = concatenated_lora_layer
|
||||
self._weight = weight
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
x_chunks: list[torch.Tensor] = []
|
||||
for lora_layer in self._concatenated_lora_layer.lora_layers:
|
||||
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
|
||||
x_chunk *= self._weight * lora_layer.scale()
|
||||
x_chunks.append(x_chunk)
|
||||
|
||||
# TODO(ryand): Generalize to support concat_axis != 0.
|
||||
assert self._concatenated_lora_layer.concat_axis == 0
|
||||
x = torch.cat(x_chunks, dim=-1)
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._concatenated_lora_layer.to(device=device, dtype=dtype)
|
||||
return self
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
|
||||
|
||||
class LoRALinearSidecarLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
lora_layer: LoRALayer,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._lora_layer = lora_layer
|
||||
self._weight = weight
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.down)
|
||||
if self._lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.mid)
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
|
||||
x *= self._weight * self._lora_layer.scale()
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._lora_layer.to(device=device, dtype=dtype)
|
||||
return self
|
||||
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class LoRASidecarModule(torch.nn.Module):
|
||||
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
|
||||
|
||||
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
|
||||
super().__init__()
|
||||
self.orig_module = orig_module
|
||||
self._lora_layers = lora_layers
|
||||
|
||||
def add_lora_layer(self, lora_layer: torch.nn.Module):
|
||||
self._lora_layers.append(lora_layer)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
x = self.orig_module(input)
|
||||
for lora_layer in self._lora_layers:
|
||||
x += lora_layer(input)
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._orig_module.to(device=device, dtype=dtype)
|
||||
for lora_layer in self._lora_layers:
|
||||
lora_layer.to(device=device, dtype=dtype)
|
||||
@@ -1,49 +0,0 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
|
||||
ConcatenatedLoRALinearSidecarLayer,
|
||||
)
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
|
||||
|
||||
def test_concatenated_lora_linear_sidecar_layer():
|
||||
"""Test that a ConcatenatedLoRALinearSidecarLayer is equivalent to patching a linear layer with the ConcatenatedLoRA
|
||||
layer.
|
||||
"""
|
||||
|
||||
# Create a linear layer.
|
||||
in_features = 5
|
||||
sub_layer_out_features = [5, 10, 15]
|
||||
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
|
||||
|
||||
# Create a ConcatenatedLoRA layer.
|
||||
rank = 4
|
||||
sub_layers: list[LoRALayer] = []
|
||||
for out_features in sub_layer_out_features:
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
|
||||
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
|
||||
|
||||
# Patch the ConcatenatedLoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += (
|
||||
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
|
||||
)
|
||||
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
|
||||
|
||||
# Create a ConcatenatedLoRALinearSidecarLayer.
|
||||
concatenated_lora_linear_sidecar_layer = ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer, weight=1.0)
|
||||
linear_with_sidecar = LoRASidecarModule(linear, [concatenated_lora_linear_sidecar_layer])
|
||||
|
||||
# Run the ConcatenatedLoRA-patched linear layer and the ConcatenatedLoRALinearSidecarLayer and assert they are
|
||||
# equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_sidecar = linear_with_sidecar(input)
|
||||
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
|
||||
@@ -1,38 +0,0 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_lora_linear_sidecar_layer():
|
||||
"""Test that a LoRALinearSidecarLayer is equivalent to patching a linear layer with the LoRA layer."""
|
||||
|
||||
# Create a linear layer.
|
||||
in_features = 10
|
||||
out_features = 20
|
||||
linear = torch.nn.Linear(in_features, out_features)
|
||||
|
||||
# Create a LoRA layer.
|
||||
rank = 4
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)
|
||||
|
||||
# Patch the LoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale()
|
||||
linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale()
|
||||
# Create a LoRALinearSidecarLayer.
|
||||
lora_linear_sidecar_layer = LoRALinearSidecarLayer(lora_layer, weight=1.0)
|
||||
linear_with_sidecar = LoRASidecarModule(linear, [lora_linear_sidecar_layer])
|
||||
|
||||
# Run the LoRA-patched linear layer and the LoRALinearSidecarLayer and assert they are equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_sidecar = linear_with_sidecar(input)
|
||||
assert torch.allclose(output_patched, output_sidecar, atol=1e-6)
|
||||
@@ -2,6 +2,7 @@ import copy
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper
|
||||
|
||||
@@ -33,3 +34,36 @@ def test_lora_linear_wrapper():
|
||||
output_patched = linear_patched(input)
|
||||
output_wrapped = lora_wrapped(input)
|
||||
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
|
||||
|
||||
|
||||
def test_concatenated_lora_linear_wrapper():
|
||||
# Create a linear layer.
|
||||
in_features = 5
|
||||
sub_layer_out_features = [5, 10, 15]
|
||||
linear = torch.nn.Linear(in_features, sum(sub_layer_out_features))
|
||||
|
||||
# Create a ConcatenatedLoRA layer.
|
||||
rank = 4
|
||||
sub_layers: list[LoRALayer] = []
|
||||
for out_features in sub_layer_out_features:
|
||||
down = torch.randn(rank, in_features)
|
||||
up = torch.randn(out_features, rank)
|
||||
bias = torch.randn(out_features)
|
||||
sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias))
|
||||
concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0)
|
||||
|
||||
# Patch the ConcatenatedLoRA layer into the linear layer.
|
||||
linear_patched = copy.deepcopy(linear)
|
||||
linear_patched.weight.data += (
|
||||
concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale()
|
||||
)
|
||||
linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale()
|
||||
|
||||
# Create a LoRALinearWrapper.
|
||||
lora_wrapped = LoRALinearWrapper(linear, [concatenated_lora_layer], [1.0])
|
||||
|
||||
# Run the ConcatenatedLoRA-patched linear layer and the LoRALinearWrapper and assert they are equal.
|
||||
input = torch.randn(1, in_features)
|
||||
output_patched = linear_patched(input)
|
||||
output_wrapped = lora_wrapped(input)
|
||||
assert torch.allclose(output_patched, output_wrapped, atol=1e-6)
|
||||
|
||||
@@ -109,56 +109,6 @@ def test_apply_lora_patches_change_device():
|
||||
torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
("cpu", 2),
|
||||
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
],
|
||||
)
|
||||
def test_apply_lora_sidecar_patches(device: str, num_layers: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
|
||||
# Initialize num_layers LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[LoRAModelRaw, float]] = []
|
||||
for _ in range(num_layers):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
)
|
||||
}
|
||||
lora = LoRAModelRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
# Run inference before patching the model.
|
||||
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
|
||||
output_before_patch = model(input)
|
||||
|
||||
# Patch the model and run inference during the patch.
|
||||
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_during_patch = model(input)
|
||||
|
||||
# Run inference after unpatching.
|
||||
output_after_patch = model(input)
|
||||
|
||||
# Check that the output before patching is different from the output during patching.
|
||||
assert not torch.allclose(output_before_patch, output_during_patch)
|
||||
|
||||
# Check that the output before patching is the same as the output after patching.
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_layers"],
|
||||
[
|
||||
@@ -239,13 +189,9 @@ def test_all_patching_methods_produce_same_output(num_layers: int):
|
||||
with LoRAPatcher.apply_lora_patches(model=model, patches=lora_models, prefix=""):
|
||||
output_lora_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_lora_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_lora_sidecar_patches = model(input)
|
||||
|
||||
with LoRAPatcher.apply_lora_wrapper_patches(model=model, patches=lora_models, prefix=""):
|
||||
output_lora_wrapper_patches = model(input)
|
||||
|
||||
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
|
||||
# differences are tolerable and expected due to the difference between sidecar vs. patching.
|
||||
assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
|
||||
assert torch.allclose(output_lora_patches, output_lora_wrapper_patches, atol=1e-5)
|
||||
|
||||
Reference in New Issue
Block a user