Mirror of https://github.com/invoke-ai/InvokeAI.git
Consolidate the LayerPatcher patching modes into a single implementation.
@@ -301,37 +301,33 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        config = transformer_info.config
        assert config is not None

        # Apply LoRA models to the transformer.
        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
        # Determine if the model is quantized.
        # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
        # slower inference than direct patching, but is agnostic to the quantization format.
        if config.format in [ModelFormat.Checkpoint]:
            # The model is non-quantized, so we can apply the LoRA weights directly into the model.
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                )
            )
            model_is_quantized = False
        elif config.format in [
            ModelFormat.BnbQuantizedLlmInt8b,
            ModelFormat.BnbQuantizednf4b,
            ModelFormat.GGUFQuantized,
        ]:
            # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
            # than directly patching the weights, but is agnostic to the quantization format.
            exit_stack.enter_context(
                LayerPatcher.apply_model_sidecar_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                )
            )
            model_is_quantized = True
        else:
            raise ValueError(f"Unsupported model format: {config.format}")

        # Apply LoRA models to the transformer.
        # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
        exit_stack.enter_context(
            LayerPatcher.apply_smart_model_patches(
                model=transformer,
                patches=self._lora_iterator(context),
                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                dtype=inference_dtype,
                cached_weights=cached_weights,
                force_sidecar_patching=model_is_quantized,
            )
        )

        # Prepare IP-Adapter extensions.
        pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions(
            pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds,

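The call site above registers the patcher on an ExitStack instead of using a with-statement, because the patch has to stay active for the remainder of the denoise call rather than a single block. A minimal, self-contained sketch of that composition (apply_patch is an illustrative stand-in, not the InvokeAI API):

from contextlib import ExitStack, contextmanager

@contextmanager
def apply_patch(model: dict, delta: int):
    # Apply the patch on entry and guarantee restoration on exit.
    model["weight"] += delta
    try:
        yield
    finally:
        model["weight"] -= delta

model = {"weight": 10}
with ExitStack() as exit_stack:
    # enter_context keeps the patch active for the lifetime of the stack,
    # mirroring exit_stack.enter_context(LayerPatcher.apply_smart_model_patches(...)).
    exit_stack.enter_context(apply_patch(model, delta=2))
    assert model["weight"] == 12  # patched while the stack is open
assert model["weight"] == 10  # restored when the stack closes
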
@@ -23,6 +23,8 @@ class LayerPatcher:
        prefix: str,
        dtype: torch.dtype,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
        force_direct_patching: bool = False,
        force_sidecar_patching: bool = False,
    ):
        """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
        module.

@@ -34,7 +36,7 @@ class LayerPatcher:
        original_modules: dict[str, torch.nn.Module] = {}
        try:
            for patch, patch_weight in patches:
                LayerPatcher._apply_smart_model_patch(
                LayerPatcher.apply_smart_model_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,

@@ -42,6 +44,8 @@ class LayerPatcher:
                    original_weights=original_weights,
                    original_modules=original_modules,
                    dtype=dtype,
                    force_direct_patching=force_direct_patching,
                    force_sidecar_patching=force_sidecar_patching,
                )

            yield

@@ -60,7 +64,7 @@ class LayerPatcher:

    @staticmethod
    @torch.no_grad()
    def _apply_smart_model_patch(
    def apply_smart_model_patch(
        model: torch.nn.Module,
        prefix: str,
        patch: ModelPatchRaw,

@@ -68,6 +72,8 @@ class LayerPatcher:
        original_weights: OriginalWeightsStorage,
        original_modules: dict[str, torch.nn.Module],
        dtype: torch.dtype,
        force_direct_patching: bool,
        force_sidecar_patching: bool,
    ):
        """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
        patching or a sidecar wrapper for each module.

@@ -94,15 +100,27 @@ class LayerPatcher:
        # Decide whether to use direct patching or a sidecar wrapper.
        # Direct patching is preferred, because it results in better runtime speed.
        # Reasons to use sidecar patching:
        # - The module is quantized, so the caller passed force_sidecar_patching=True.
        # - The module is already wrapped in a BaseSidecarWrapper.
        # - The module is quantized.
        # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
        #   CPU, since this would double the RAM usage)
        # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
        # and that the caller will use the 'apply_model_sidecar_patches' method if the layer is quantized.
        # and that the caller will set force_sidecar_patching=True if the layer is quantized.
        # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
        # forcing full patching even on the CPU?
        if isinstance(module, BaseSidecarWrapper) or LayerPatcher._is_any_part_of_layer_on_cpu(module):
        use_sidecar_patching = False
        if force_direct_patching and force_sidecar_patching:
            raise ValueError("Cannot force both direct and sidecar patching.")
        elif force_direct_patching:
            use_sidecar_patching = False
        elif force_sidecar_patching:
            use_sidecar_patching = True
        elif isinstance(module, BaseSidecarWrapper):
            use_sidecar_patching = True
        elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
            use_sidecar_patching = True

        if use_sidecar_patching:
            LayerPatcher._apply_model_layer_wrapper_patch(
                model=model,
                module_to_patch=module,

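The branch order above can be read as a small predicate: the explicit force flags are checked first (and must not conflict), then an already-wrapped module, then CPU residency, with direct patching as the default. A standalone sketch of that precedence (a hypothetical helper, not the InvokeAI implementation):

def should_use_sidecar(
    force_direct: bool,
    force_sidecar: bool,
    is_already_wrapped: bool,
    is_on_cpu: bool,
) -> bool:
    # Explicit requests from the caller take precedence and must not conflict.
    if force_direct and force_sidecar:
        raise ValueError("Cannot force both direct and sidecar patching.")
    if force_direct:
        return False
    if force_sidecar:
        return True
    # Otherwise fall back to properties of the module itself.
    if is_already_wrapped:
        return True
    if is_on_cpu:
        return True
    return False

# Example: an unwrapped GPU module with no force flags gets direct patching.
assert should_use_sidecar(False, False, False, False) is False
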
@@ -125,89 +143,6 @@ class LayerPatcher:
    def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
        return any(p.device.type == "cpu" for p in layer.parameters())

    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_model_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[ModelPatchRaw, float]],
        prefix: str,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Apply one or more LoRA patches to a model within a context manager.

        Args:
            model (torch.nn.Module): The model to patch.
            patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
                associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
                all at once.
            prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
            cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in
                CPU RAM, for efficient unpatching purposes.
        """
        original_weights = OriginalWeightsStorage(cached_weights)
        try:
            for patch, patch_weight in patches:
                LayerPatcher.apply_model_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                )
                del patch

            yield
        finally:
            for param_key, weight in original_weights.get_changed_weights():
                cur_param = model.get_parameter(param_key)
                cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)

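The try/finally structure above is the heart of direct patching: mutate weights in place, yield, then restore every changed tensor to whatever device and dtype the parameter has at exit time. A minimal sketch of the same pattern on a plain nn.Module, with a simple dict standing in for OriginalWeightsStorage (all names here are hypothetical):

from contextlib import contextmanager

import torch

@contextmanager
def patch_weights(model: torch.nn.Module, deltas: dict[str, torch.Tensor]):
    # Save a copy of every parameter we are about to modify, then patch in place.
    originals: dict[str, torch.Tensor] = {}
    try:
        for name, delta in deltas.items():
            param = model.get_parameter(name)
            originals[name] = param.detach().clone()
            param.data += delta.to(dtype=param.dtype, device=param.device)
        yield
    finally:
        # Restore the original weights on whatever device the model is on now.
        for name, original in originals.items():
            param = model.get_parameter(name)
            param.data = original.to(dtype=param.dtype, device=param.device, copy=True)

model = torch.nn.Linear(4, 4)
before = model.weight.detach().clone()
with patch_weights(model, {"weight": torch.ones(4, 4)}):
    assert not torch.equal(model.weight, before)  # patched inside the context
assert torch.equal(model.weight, before)  # restored on exit
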
    @staticmethod
    @torch.no_grad()
    def apply_model_patch(
        model: torch.nn.Module,
        prefix: str,
        patch: ModelPatchRaw,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        """Apply a single LoRA patch to a model.

        Args:
            model (torch.nn.Module): The model to patch.
            prefix (str): A string prefix that precedes keys used in the LoRAs weight layers.
            patch (LoRAModelRaw): The LoRA model to patch in.
            patch_weight (float): The weight of the LoRA patch.
            original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching.
        """
        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LayerPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            LayerPatcher._apply_model_layer_patch(
                module_to_patch=module,
                module_to_patch_key=module_key,
                patch=layer,
                patch_weight=patch_weight,
                original_weights=original_weights,
            )

    @staticmethod
    @torch.no_grad()
    def _apply_model_layer_patch(

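For a single linear layer, the direct patch applied by _apply_model_layer_patch boils down to adding the scaled low-rank product of the LoRA matrices into the existing weight. A small numeric sketch of that update (illustrative only; the real LoRA layer types handle additional details such as scaling and bias terms):

import torch

out_features, in_features, rank = 8, 4, 2
weight = torch.zeros(out_features, in_features)

# LoRA factors: `down` projects into the low-rank space, `up` projects back out.
down = torch.ones(rank, in_features)
up = torch.ones(out_features, rank)
patch_weight = 0.5

# Direct patching merges the low-rank update into the base weight in place.
weight += patch_weight * (up @ down)

# With all-ones factors, every element of (up @ down) equals `rank`,
# so each weight entry grows by rank * patch_weight = 1.0.
assert torch.allclose(weight, torch.full((out_features, in_features), 1.0))
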
@@ -254,89 +189,6 @@ class LayerPatcher:

        patch.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_model_sidecar_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[ModelPatchRaw, float]],
        prefix: str,
        dtype: torch.dtype,
    ):
        """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
        overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any
        quantization format.

        Args:
            model (torch.nn.Module): The model to patch.
            patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and
                associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
                all at once.
            prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
            dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
                since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
                different from their compute dtype.
        """
        original_modules: dict[str, torch.nn.Module] = {}
        try:
            for patch, patch_weight in patches:
                LayerPatcher._apply_model_sidecar_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_modules=original_modules,
                    dtype=dtype,
                )
            yield
        finally:
            # Restore original modules.
            # Note: This logic assumes no nested modules in original_modules.
            for module_key, orig_module in original_modules.items():
                module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
                parent_module = model.get_submodule(module_parent_key)
                LayerPatcher._set_submodule(parent_module, module_name, orig_module)

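Restoring the original modules in the finally block above relies on splitting a dotted module key into a parent path and an attribute name, then re-assigning that attribute on the parent. A small self-contained sketch of that mechanic (split_parent_key is a hypothetical stand-in for LayerPatcher._split_parent_key/_set_submodule):

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

def split_parent_key(module_key: str) -> tuple[str, str]:
    # "block.linear" -> ("block", "linear"); a top-level key has an empty parent.
    parent, _, name = module_key.rpartition(".")
    return parent, name

model = Model()
original = model.get_submodule("block.linear")

# Swap the submodule out (as sidecar patching does), then restore it by key.
parent_key, name = split_parent_key("block.linear")
parent = model.get_submodule(parent_key) if parent_key else model
setattr(parent, name, torch.nn.Identity())
assert isinstance(model.get_submodule("block.linear"), torch.nn.Identity)

setattr(parent, name, original)
assert model.get_submodule("block.linear") is original
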
    @staticmethod
    def _apply_model_sidecar_patch(
        model: torch.nn.Module,
        patch: ModelPatchRaw,
        patch_weight: float,
        prefix: str,
        original_modules: dict[str, torch.nn.Module],
        dtype: torch.dtype,
    ):
        """Apply a single LoRA sidecar patch to a model."""

        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LayerPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            LayerPatcher._apply_model_layer_wrapper_patch(
                model=model,
                module_to_patch=module,
                module_to_patch_key=module_key,
                patch=layer,
                patch_weight=patch_weight,
                original_modules=original_modules,
                dtype=dtype,
            )

    @staticmethod
    @torch.no_grad()
    def _apply_model_layer_wrapper_patch(

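The _apply_model_layer_wrapper_patch call above replaces the target module with a wrapper that leaves the original (possibly quantized) layer untouched and adds the LoRA contribution during the forward pass at an explicit compute dtype. A minimal sketch of that idea (a hypothetical wrapper, not InvokeAI's BaseSidecarWrapper):

import torch

class LoRASidecar(torch.nn.Module):
    """Wrap a base layer and add a low-rank LoRA term to its output."""

    def __init__(self, base: torch.nn.Module, up: torch.Tensor, down: torch.Tensor,
                 weight: float, dtype: torch.dtype):
        super().__init__()
        self.base = base      # original layer, its weights are never modified
        self.up = up
        self.down = down
        self.weight = weight
        self.dtype = dtype    # compute dtype for the LoRA term

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.base(x)
        # LoRA term computed separately: x @ down^T @ up^T, scaled by the patch weight.
        lora_out = (x.to(self.dtype) @ self.down.T.to(self.dtype)) @ self.up.T.to(self.dtype)
        return base_out + self.weight * lora_out.to(base_out.dtype)

base = torch.nn.Linear(4, 8, bias=False)
wrapped = LoRASidecar(base, up=torch.ones(8, 2), down=torch.ones(2, 4),
                      weight=0.5, dtype=torch.float32)
x = torch.randn(1, 4)
# The wrapper's output equals the base output plus the scaled LoRA contribution.
expected = base(x) + 0.5 * (x @ torch.ones(2, 4).T @ torch.ones(8, 2).T)
assert torch.allclose(wrapped(x), expected)
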
@@ -31,12 +31,16 @@ class LoRAExt(ExtensionBase):
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        lora_model = self._node_context.models.load(self._model_id).model
        assert isinstance(lora_model, ModelPatchRaw)
        LayerPatcher.apply_model_patch(
        LayerPatcher.apply_smart_model_patch(
            model=unet,
            prefix="lora_unet_",
            patch=lora_model,
            patch_weight=self._weight,
            original_weights=original_weights,
            original_modules={},
            dtype=unet.dtype,
            force_direct_patching=True,
            force_sidecar_patching=False,
        )
        del lora_model

@@ -30,160 +30,20 @@ class DummyModuleWithTwoLayers(torch.nn.Module):

@pytest.mark.parametrize(
    ["device", "num_loras"],
    "device",
    [
        ("cpu", 1),
        pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        ("cpu", 2),
        pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora_patches(device: str, num_loras: int):
    """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the
    correct result, and that model/LoRA tensors are moved between devices as expected.
    """

    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16)

    # Initialize num_loras LoRA models with weights of 0.5.
    lora_weight = 0.5
    lora_models: list[tuple[ModelPatchRaw, float]] = []
    for _ in range(num_loras):
        lora_layers = {
            "linear_layer_1": LoRALayer.from_state_dict_values(
                values={
                    "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                    "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
                },
            )
        }
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)

    with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model.linear_layer_1.weight.data.device.type == device
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight)

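The expected_patched_linear_weight value in the test above follows from the all-ones LoRA factors: every element of (up @ down) equals lora_rank, each patch is scaled by lora_weight, and the num_loras patches add up. A quick standalone check of that arithmetic, independent of the test fixtures:

import torch

linear_in_features, linear_out_features = 4, 8
lora_rank, lora_weight, num_loras = 2, 0.5, 2

orig = torch.zeros(linear_out_features, linear_in_features)
patched = orig.clone()
for _ in range(num_loras):
    up = torch.ones(linear_out_features, lora_rank)
    down = torch.ones(lora_rank, linear_in_features)
    patched += lora_weight * (up @ down)

# Matches the test's expectation: orig + lora_rank * lora_weight * num_loras.
assert torch.allclose(patched, orig + lora_rank * lora_weight * num_loras)
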
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)

@pytest.mark.parametrize("num_loras", [1, 2])
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_loras"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
("cpu", 2),
|
||||
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
],
|
||||
)
|
||||
def test_apply_lora_sidecar_patches(device: str, num_loras: int):
|
||||
"""Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly."""
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
linear_out_features = 8
|
||||
lora_rank = 2
|
||||
model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype)
|
||||
|
||||
# Initialize num_loras LoRA models with weights of 0.5.
|
||||
lora_weight = 0.5
|
||||
lora_models: list[tuple[ModelPatchRaw, float]] = []
|
||||
for _ in range(num_loras):
|
||||
lora_layers = {
|
||||
"linear_layer_1": LoRALayer.from_state_dict_values(
|
||||
values={
|
||||
"lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
|
||||
"lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
|
||||
},
|
||||
)
|
||||
}
|
||||
lora = ModelPatchRaw(lora_layers)
|
||||
lora_models.append((lora, lora_weight))
|
||||
|
||||
# Run inference before patching the model.
|
||||
input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
|
||||
output_before_patch = model(input)
|
||||
|
||||
# Patch the model and run inference during the patch.
|
||||
with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
|
||||
output_during_patch = model(input)
|
||||
|
||||
# Run inference after unpatching.
|
||||
output_after_patch = model(input)
|
||||
|
||||
# Check that the output before patching is different from the output during patching.
|
||||
assert not torch.allclose(output_before_patch, output_during_patch)
|
||||
|
||||
# Check that the output before patching is the same as the output after patching.
|
||||
assert torch.allclose(output_before_patch, output_after_patch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["device", "num_loras"],
|
||||
[
|
||||
("cpu", 1),
|
||||
pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
("cpu", 2),
|
||||
pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
|
||||
],
|
||||
["force_sidecar_patching", "force_direct_patching"], [(True, False), (False, True), (False, False)]
|
||||
)
|
||||
@torch.no_grad()
|
||||
def test_apply_smart_model_patches(device: str, num_loras: int):
|
||||
def test_apply_smart_model_patches(
|
||||
device: str, num_loras: int, force_sidecar_patching: bool, force_direct_patching: bool
|
||||
):
|
||||
"""Test the basic behavior of ModelPatcher.apply_smart_model_patches(...). Check that unpatching works correctly."""
|
||||
dtype = torch.float16
|
||||
linear_in_features = 4
|
||||
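Several of the parametrizations above are reorganized from one combined ["device", "num_loras"] list into separate decorators. Stacked pytest.mark.parametrize decorators generate the cross product of their values, so the split preserves the same device/LoRA-count combinations. A tiny self-contained illustration of that pytest behavior:

import pytest

@pytest.mark.parametrize("num_loras", [1, 2])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_cross_product(device: str, num_loras: int):
    # pytest runs this test 4 times: (cpu, 1), (cpu, 2), (cuda, 1), (cuda, 2).
    assert device in ("cpu", "cuda") and num_loras in (1, 2)
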
@@ -206,12 +66,44 @@ def test_apply_smart_model_patches(device: str, num_loras: int):
        lora = ModelPatchRaw(lora_layers)
        lora_models.append((lora, lora_weight))

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras)

    # Run inference before patching the model.
    input = torch.randn(1, linear_in_features, device=device, dtype=dtype)
    output_before_patch = model(input)

    expect_sidecar_wrappers = device == "cpu"
    if force_sidecar_patching:
        expect_sidecar_wrappers = True
    elif force_direct_patching:
        expect_sidecar_wrappers = False

    # Patch the model and run inference during the patch.
    with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
    with LayerPatcher.apply_smart_model_patches(
        model=model,
        patches=lora_models,
        prefix="",
        dtype=dtype,
        force_direct_patching=force_direct_patching,
        force_sidecar_patching=force_sidecar_patching,
    ):
        if expect_sidecar_wrappers:
            # There should be sidecar wrappers in the model.
            assert isinstance(model.linear_layer_1, BaseSidecarWrapper)
        else:
            # There should be no sidecar wrappers in the model.
            assert not isinstance(model.linear_layer_1, BaseSidecarWrapper)
            torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight)

        # After patching, the patched model should still be on its original device.
        assert model.linear_layer_1.weight.data.device.type == device

        # After patching, all LoRA layer weights should have been moved back to the cpu.
        for lora, _ in lora_models:
            assert lora.layers["linear_layer_1"].up.device.type == "cpu"
            assert lora.layers["linear_layer_1"].down.device.type == "cpu"

        output_during_patch = model(input)

    # Run inference after unpatching.

@@ -320,16 +212,94 @@ def test_all_patching_methods_produce_same_output(num_loras: int):

    input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype)

    with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""):
        output_lora_patches = model(input)
    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=lora_models, prefix="", dtype=dtype, force_direct_patching=True
    ):
        output_force_direct = model(input)

    with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_lora_sidecar_patches = model(input)
    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=lora_models, prefix="", dtype=dtype, force_sidecar_patching=True
    ):
        output_force_sidecar = model(input)

    with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype):
        output_smart_lora_patches = model(input)
        output_smart = model(input)

    # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical
    # differences are tolerable and expected due to the difference between sidecar vs. patching.
    assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5)
    assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5)
    assert torch.allclose(output_force_direct, output_force_sidecar, atol=1e-5)
    assert torch.allclose(output_force_direct, output_smart, atol=1e-5)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_smart_model_patches_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)

    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)

    orig_linear_weight = model.linear_layer_1.weight.data.detach().clone()

    with LayerPatcher.apply_smart_model_patches(
        model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True
    ):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model.linear_layer_1.weight.data.device.type == "cpu"

        # There should be no sidecar wrappers in the model.
        assert not isinstance(model.linear_layer_1, BaseSidecarWrapper)

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model.linear_layer_1.weight.data.device.type == "cuda"
    torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False)

def test_apply_smart_model_patches_force_sidecar_and_direct_patching():
    """Test that ModelPatcher.apply_smart_model_patches(..., force_direct_patching=True, force_sidecar_patching=True)
    raises an error.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_rank = 2
    model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)
    lora_layers = {
        "linear_layer_1": LoRALayer.from_state_dict_values(
            values={
                "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = ModelPatchRaw(lora_layers)
    with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."):
        with LayerPatcher.apply_smart_model_patches(
            model=model,
            patches=[(lora, 0.5)],
            prefix="",
            dtype=torch.float16,
            force_direct_patching=True,
            force_sidecar_patching=True,
        ):
            pass