Move handling of LoRA scale and patch weight down into the layer patch classes.

This commit is contained in:
Ryan Dick
2024-12-13 23:15:30 +00:00
parent e7e3f7e144
commit fe09f2d27a
7 changed files with 14 additions and 24 deletions

View File

@@ -5,7 +5,7 @@ import torch
class BaseLayerPatch(ABC):
@abstractmethod
def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
"""Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
from the returned dict are not updated.
"""

View File

@@ -57,11 +57,12 @@ class LoRALayerBase(BaseLayerPatch):
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
scale = self.scale()
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
params["bias"] = bias * (weight * scale)
return params
@classmethod

View File

@@ -14,9 +14,9 @@ class SetParameterLayer(BaseLayerPatch):
self.weight = weight
self.param_name = param_name
def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
diff = self.weight - orig_module.get_parameter(self.param_name)
return {self.param_name: diff}
return {self.param_name: diff * weight}
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -110,11 +110,6 @@ class LoRAPatcher:
device = first_param.device
dtype = first_param.dtype
# TODO(ryand): Move this down into the patch.
patch_scale = 1.0
if hasattr(patch, "scale"):
patch_scale = patch.scale()
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
@@ -123,7 +118,7 @@ class LoRAPatcher:
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, param_weight in patch.get_parameters(module_to_patch).items():
for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
param_key = module_to_patch_key + "." + param_name
module_param = module_to_patch.get_parameter(param_name)
@@ -148,7 +143,6 @@ class LoRAPatcher:
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
)
module_param = expanded_weight
param_weight *= patch_weight * patch_scale
module_param += param_weight.to(dtype=dtype)
patch.to(device=TorchDevice.CPU_DEVICE)

View File

@@ -40,7 +40,7 @@ class BaseSidecarWrapper(torch.nn.Module):
for patch, patch_weight in patches_and_weights:
# TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
# module, this might fail or return incorrect results.
layer_params = patch.get_parameters(self._orig_module)
layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
for param_name, param_weight in layer_params.items():
orig_param = self._orig_module.get_parameter(param_name)
@@ -48,15 +48,10 @@ class BaseSidecarWrapper(torch.nn.Module):
if orig_param.shape != param_weight.shape:
param_weight = param_weight.reshape(orig_param.shape)
# TODO(ryand): Move scale handling down into the patch.
scale = 1
if hasattr(patch, "scale"):
scale = patch.scale() # type: ignore
if param_name not in params:
params[param_name] = param_weight * (scale * patch_weight)
params[param_name] = param_weight
else:
params[param_name] += param_weight * (scale * patch_weight)
params[param_name] += param_weight
return params

View File

@@ -107,8 +107,8 @@ def test_lora_layer_get_parameters():
# Create mock original module
orig_module = torch.nn.Linear(in_features, out_features)
params = layer.get_parameters(orig_module)
params = layer.get_parameters(orig_module, weight=1.0)
assert "weight" in params
assert params["weight"].shape == orig_module.weight.shape
assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha / rank)
assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha)
assert "bias" not in params # No bias in this case

View File

@@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters():
target_weight = torch.randn(8, 4)
layer = SetParameterLayer(param_name="weight", weight=target_weight)
params = layer.get_parameters(orig_module)
params = layer.get_parameters(orig_module, weight=1.0)
assert len(params) == 1
new_weight = orig_module.weight + params["weight"]
assert torch.allclose(new_weight, target_weight)