Consolidate the LayerPatcher patching modes into a single implementation.

Ryan Dick
2024-12-17 18:33:36 +00:00
parent c37bb6375c
commit e0c899104b
4 changed files with 172 additions and 350 deletions

View File

@@ -301,37 +301,33 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
config = transformer_info.config
assert config is not None
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
model_is_quantized = False
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
LayerPatcher.apply_model_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)
model_is_quantized = True
else:
raise ValueError(f"Unsupported model format: {config.format}")
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
# Prepare IP-Adapter extensions.
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,
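The net effect of this hunk: the caller now reduces the format check to a single boolean and makes one patching call, instead of choosing between two patching entry points. A minimal sketch of the resulting flow, assuming the names visible in this diff (LayerPatcher, ModelFormat, FLUX_LORA_TRANSFORMER_PREFIX); the apply_loras helper and its signature are hypothetical:

from contextlib import ExitStack

def apply_loras(transformer, config, loras, inference_dtype, cached_weights, exit_stack: ExitStack):
    # Map the model format to a single boolean; the patcher handles the rest.
    if config.format in [ModelFormat.Checkpoint]:
        model_is_quantized = False
    elif config.format in [
        ModelFormat.BnbQuantizedLlmInt8b,
        ModelFormat.BnbQuantizednf4b,
        ModelFormat.GGUFQuantized,
    ]:
        model_is_quantized = True
    else:
        raise ValueError(f"Unsupported model format: {config.format}")

    # One entry point for both quantized and non-quantized models.
    exit_stack.enter_context(
        LayerPatcher.apply_smart_model_patches(
            model=transformer,
            patches=loras,
            prefix=FLUX_LORA_TRANSFORMER_PREFIX,
            dtype=inference_dtype,
            cached_weights=cached_weights,
            force_sidecar_patching=model_is_quantized,
        )
    )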

View File

@@ -23,6 +23,8 @@ class LayerPatcher:
prefix: str,
dtype: torch.dtype,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
force_direct_patching: bool = False,
force_sidecar_patching: bool = False,
):
"""Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
module.
@@ -34,7 +36,7 @@ class LayerPatcher:
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LayerPatcher.apply_smart_model_patch(
model=model,
prefix=prefix,
patch=patch,
@@ -42,6 +44,8 @@ class LayerPatcher:
original_weights=original_weights,
original_modules=original_modules,
dtype=dtype,
force_direct_patching=force_direct_patching,
force_sidecar_patching=force_sidecar_patching,
)
yield
@@ -60,7 +64,7 @@ class LayerPatcher:
@staticmethod
@torch.no_grad()
def apply_smart_model_patch(
model: torch.nn.Module,
prefix: str,
patch: ModelPatchRaw,
@@ -68,6 +72,8 @@ class LayerPatcher:
original_weights: OriginalWeightsStorage,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
force_direct_patching: bool,
force_sidecar_patching: bool,
):
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
patching or a sidecar wrapper for each module.
@@ -94,15 +100,27 @@ class LayerPatcher:
# Decide whether to use direct patching or a sidecar wrapper.
# Direct patching is preferred, because it results in better runtime speed.
# Reasons to use sidecar patching:
# - The module is quantized, so the caller passed force_sidecar_patching=True.
# - The module is already wrapped in a BaseSidecarWrapper.
# - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
# CPU, since this would double the RAM usage)
# NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
# and that the caller will set force_sidecar_patching=True if the layer is quantized.
# TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
# forcing full patching even on the CPU?
use_sidecar_patching = False
if force_direct_patching and force_sidecar_patching:
raise ValueError("Cannot force both direct and sidecar patching.")
elif force_direct_patching:
use_sidecar_patching = False
elif force_sidecar_patching:
use_sidecar_patching = True
elif isinstance(module, BaseSidecarWrapper):
use_sidecar_patching = True
elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
use_sidecar_patching = True
if use_sidecar_patching:
LayerPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
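The precedence of the checks above, pulled out as a standalone function for clarity. This is a sketch, not code from the repository; only torch and the BaseSidecarWrapper class shown in this diff are assumed:

import torch

def choose_sidecar_patching(
    module: torch.nn.Module, force_direct_patching: bool, force_sidecar_patching: bool
) -> bool:
    """Return True if the module should receive a sidecar patch rather than a direct weight patch."""
    if force_direct_patching and force_sidecar_patching:
        raise ValueError("Cannot force both direct and sidecar patching.")
    if force_direct_patching:
        return False
    if force_sidecar_patching:
        return True
    # Heuristics when neither flag is set: a module that is already wrapped must
    # stay wrapped, and a (partially) CPU-resident module is sidecar-patched to
    # avoid storing a second full copy of its weights in RAM.
    if isinstance(module, BaseSidecarWrapper):
        return True
    return any(p.device.type == "cpu" for p in module.parameters())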
@@ -125,89 +143,6 @@ class LayerPatcher:
def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
return any(p.device.type == "cpu" for p in layer.parameters())
@staticmethod
@torch.no_grad()
@contextmanager
def apply_model_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more LoRA patches to a model within a context manager.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
CPU RAM, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
LayerPatcher.apply_model_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
del patch
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
cur_param = model.get_parameter(param_key)
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
@staticmethod
@torch.no_grad()
def apply_model_patch(
model: torch.nn.Module,
prefix: str,
patch: ModelPatchRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""Apply a single LoRA patch to a model.
Args:
model (torch.nn.Module): The model to patch.
prefix (str): A string prefix that precedes keys used in the LoRA's weight layers.
patch (LoRAModelRaw): The LoRA model to patch in.
patch_weight (float): The weight of the LoRA patch.
original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
"""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
LayerPatcher._apply_model_layer_patch(
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_weights=original_weights,
)
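The flattened-key convention described in the comment above is easiest to see on a concrete key. A small illustration (the key strings are hypothetical examples, not taken from a real checkpoint):

# Non-flattened: a real submodule path, usable with model.get_submodule(...).
non_flattened = "down_blocks.0.attentions.0.proj_in"

# Flattened: every '.' replaced with '_', as in legacy LoRA checkpoints. The
# path can no longer be split unambiguously, so submodules must be searched.
flattened = non_flattened.replace(".", "_")
assert flattened == "down_blocks_0_attentions_0_proj_in"

# "." in the key is therefore a reliable test for the non-flattened style:
assert ("." not in flattened) and ("." in non_flattened)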
@staticmethod
@torch.no_grad()
def _apply_model_layer_patch(
@@ -254,89 +189,6 @@ class LayerPatcher:
patch.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
@torch.no_grad()
@contextmanager
def apply_model_sidecar_patches(
model: torch.nn.Module,
patches: Iterable[Tuple[ModelPatchRaw, float]],
prefix: str,
dtype: torch.dtype,
):
"""Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
quantization format.
Args:
model (torch.nn.Module): The model to patch.
patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
all at once.
prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
different from their compute dtype.
"""
original_modules: dict[str, torch.nn.Module] = {}
try:
for patch, patch_weight in patches:
LayerPatcher._apply_model_sidecar_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
yield
finally:
# Restore original modules.
# Note: This logic assumes no nested modules in original_modules.
for module_key, orig_module in original_modules.items():
module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
parent_module = model.get_submodule(module_parent_key)
LayerPatcher._set_submodule(parent_module, module_name, orig_module)
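For intuition, a sidecar patch leaves the (possibly quantized) base layer untouched and adds the LoRA contribution at call time. A schematic wrapper in that spirit — illustrative only, not the actual BaseSidecarWrapper implementation:

import torch

class SchematicLinearSidecar(torch.nn.Module):
    """Computes base(x) + patch_weight * up(down(x)) without modifying base's parameters."""

    def __init__(self, base: torch.nn.Module, down: torch.Tensor, up: torch.Tensor,
                 patch_weight: float, dtype: torch.dtype):
        super().__init__()
        self.base = base
        # The compute dtype is passed explicitly: a quantized base layer's weight
        # dtype (e.g. int8 / nf4) generally differs from its compute dtype.
        self.register_buffer("down", down.to(dtype))  # shape: (rank, in_features)
        self.register_buffer("up", up.to(dtype))      # shape: (out_features, rank)
        self.patch_weight = patch_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = torch.nn.functional.linear(torch.nn.functional.linear(x, self.down), self.up)
        return self.base(x) + self.patch_weight * delta

Because the base layer is only ever called, never rewritten, this scheme works for any quantization format, at the cost of two extra matmuls on every forward pass.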
@staticmethod
def _apply_model_sidecar_patch(
model: torch.nn.Module,
patch: ModelPatchRaw,
patch_weight: float,
prefix: str,
original_modules: dict[str, torch.nn.Module],
dtype: torch.dtype,
):
"""Apply a single LoRA sidecar patch to a model."""
if patch_weight == 0:
return
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
# without searching, but some legacy code still uses flattened keys.
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
prefix_len = len(prefix)
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = LayerPatcher._get_submodule(
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
)
LayerPatcher._apply_model_layer_wrapper_patch(
model=model,
module_to_patch=module,
module_to_patch_key=module_key,
patch=layer,
patch_weight=patch_weight,
original_modules=original_modules,
dtype=dtype,
)
@staticmethod
@torch.no_grad()
def _apply_model_layer_wrapper_patch(

View File

@@ -31,12 +31,16 @@ class LoRAExt(ExtensionBase):
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
assert isinstance(lora_model, ModelPatchRaw)
LayerPatcher.apply_smart_model_patch(
model=unet,
prefix="lora_unet_",
patch=lora_model,
patch_weight=self._weight,
original_weights=original_weights,
original_modules={},
dtype=unet.dtype,
force_direct_patching=True,
force_sidecar_patching=False,
)
del lora_model

View File

@@ -30,160 +30,20 @@ class DummyModuleWithTwoLayers(torch.nn.Module):
@pytest.mark.parametrize(
["device", "num_loras"],
[
("cpu", 1),
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
("cpu", 2),
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
@torch.no_grad()
def test_apply_lora_patches(device: str, num_loras: int):
"""Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
correct result, and that model/LoRA tensors are moved between devices as expected.
"""
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16)
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
)
}
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)
with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
for lora, _ in lora_models:
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
assert lora.layers["linear_layer_1"].down.device.type == "cpu"
# After patching, the patched model should still be on its original device.
assert model.linear_layer_1.weight.data.device.type == device
torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)
# After unpatching, the original model weights should have been restored on the original device.
assert model.linear_layer_1.weight.data.device.type == device
torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_patches_change_device():
"""Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
still behaves correctly.
"""
linear_in_features = 4
linear_out_features = 8
lora_dim = 2
# Initialize the model on the CPU.
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
},
)
}
lora = ModelPatchRaw(lora_layers)
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
# After patching, the patched model should still be on the CPU.
assert model.linear_layer_1.weight.data.device.type == "cpu"
# Move the model to the GPU.
assert model.to("cuda")
# After unpatching, the original model weights should have been restored on the GPU.
assert model.linear_layer_1.weight.data.device.type == "cuda"
torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)
@pytest.mark.parametrize("num_loras", [1, 2])
@pytest.mark.parametrize(
["device", "num_loras"],
[
("cpu", 1),
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
("cpu", 2),
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
def test_apply_lora_sidecar_patches(device: str, num_loras: int):
"""Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
dtype = torch.float16
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
# Initialize num_loras LoRA models with weights of 0.5.
lora_weight = 0.5
lora_models: list[tuple[ModelPatchRaw, float]] = []
for _ in range(num_loras):
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
)
}
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
# Run inference before patching the model.
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
output_before_patch = model(input)
# Patch the model and run inference during the patch.
with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_during_patch = model(input)
# Run inference after unpatching.
output_after_patch = model(input)
# Check that the output before patching is different from the output during patching.
assert not torch.allclose(output_before_patch, output_during_patch)
# Check that the output before patching is the same as the output after patching.
assert torch.allclose(output_before_patch, output_after_patch)
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
@pytest.mark.parametrize("num_loras", [1, 2])
@pytest.mark.parametrize(
["force_sidecar_patching", "force_direct_patching"], [(True, False), (False, True), (False, False)]
)
@torch.no_grad()
def test_apply_smart_model_patches(
device: str, num_loras: int, force_sidecar_patching: bool, force_direct_patching: bool
):
"""Test the basic behavior of ModelPatcher.apply_smart_model_patches(...). Check that unpatching works correctly."""
dtype = torch.float16
linear_in_features = 4
@@ -206,12 +66,44 @@ def test_apply_smart_model_patches(device: str, num_loras: int):
lora = ModelPatchRaw(lora_layers)
lora_models.append((lora, lora_weight))
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)
# Run inference before patching the model.
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
output_before_patch = model(input)
expect_sidecar_wrappers = device == "cpu"
if force_sidecar_patching:
expect_sidecar_wrappers = True
elif force_direct_patching:
expect_sidecar_wrappers = False
# Patch the model and run inference during the patch.
with LayerPatcher.apply_smart_model_patches(
model=model,
patches=lora_models,
prefix="",
dtype=dtype,
force_direct_patching=force_direct_patching,
force_sidecar_patching=force_sidecar_patching,
):
if expect_sidecar_wrappers:
# There should be sidecar wrappers in the model.
assert isinstance(model.linear_layer_1, BaseSidecarWrapper)
else:
# There should be no sidecar wrappers in the model.
assert not isinstance(model.linear_layer_1, BaseSidecarWrapper)
torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)
# After patching, the patched model should still be on its original device.
assert model.linear_layer_1.weight.data.device.type == device
# After patching, all LoRA layer weights should have been moved back to the cpu.
for lora, _ in lora_models:
assert lora.layers["linear_layer_1"].up.device.type == "cpu"
assert lora.layers["linear_layer_1"].down.device.type == "cpu"
output_during_patch = model(input)
# Run inference after unpatching.
@@ -320,16 +212,94 @@ def test_all_patching_methods_produce_same_output(num_loras: int):
input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)
with LayerPatcher.apply_smart_model_patches(
model=model, patches=lora_models, prefix="", dtype=dtype, force_direct_patching=True
):
output_force_direct = model(input)
with LayerPatcher.apply_smart_model_patches(
model=model, patches=lora_models, prefix="", dtype=dtype, force_sidecar_patching=True
):
output_force_sidecar = model(input)
with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
output_smart = model(input)
# Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
# differences are tolerable and expected due to the difference between sidecar and direct patching.
assert torch.allclose(output_force_direct, output_force_sidecar, atol=1e-5)
assert torch.allclose(output_force_direct, output_smart, atol=1e-5)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_smart_model_patches_change_device():
"""Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
still behaves correctly.
"""
linear_in_features = 4
linear_out_features = 8
lora_dim = 2
# Initialize the model on the CPU.
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
},
)
}
lora = ModelPatchRaw(lora_layers)
orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
with LayerPatcher.apply_smart_model_patches(
model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
# After patching, the patched model should still be on the CPU.
assert model.linear_layer_1.weight.data.device.type == "cpu"
# There should be no sidecar wrappers in the model.
assert not isinstance(model.linear_layer_1, BaseSidecarWrapper)
# Move the model to the GPU.
assert model.to("cuda")
# After unpatching, the original model weights should have been restored on the GPU.
assert model.linear_layer_1.weight.data.device.type == "cuda"
torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)
def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
"""Test that ModelPatcher.apply_smart_model_patches(..., force_direct_patching=True, force_sidecar_patching=True)
raises an error.
"""
linear_in_features = 4
linear_out_features = 8
lora_rank = 2
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
lora_layers = {
"linear_layer_1": LoRALayer.from_state_dict_values(
values={
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
},
)
}
lora = ModelPatchRaw(lora_layers)
with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."):
with LayerPatcher.apply_smart_model_patches(
model=model,
patches=[(lora, 0.5)],
prefix="",
dtype=torch.float16,
force_direct_patching=True,
force_sidecar_patching=True,
):
pass