Update DoRALayer to 1) apply alpha scaling to delta_v in get_weight(), and 2) add a custom get_parameters() override that raises a clear error if the base model is incompatible (e.g. bitsandbytes-quantized).

Ryan Dick
2025-01-24 17:29:42 +00:00
parent 5d472ac1b8
commit e7fb435cc5


@@ -62,9 +62,7 @@ class DoRALayer(LoRALayerBase):
         delta_v = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
         delta_v = delta_v.reshape(orig_weight.shape)
 
-        # TODO(ryand): Should alpha be applied to delta_v here rather than the final diff?
-        # TODO(ryand): I expect this to fail if the original weight is BnB Quantized. This class shouldn't have to worry
-        # about that, but we should add a clear error message further up the stack.
+        delta_v = delta_v * self.scale()
 
         # At this point, out_weight is the unnormalized direction matrix.
         out_weight = orig_weight + delta_v
@@ -90,3 +88,25 @@ class DoRALayer(LoRALayerBase):
     def calc_size(self) -> int:
         return super().calc_size() + calc_tensors_size([self.up, self.down, self.dora_scale])
+
+    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
+        if any(p.device.type == "meta" for p in orig_parameters.values()):
+            # If any of the original parameters are on the 'meta' device, we assume this is because the base model is in
+            # a quantization format that doesn't allow easy dequantization.
+            raise RuntimeError(
+                "The base model quantization format (likely bitsandbytes) is not supported for DoRA patches."
+            )
+
+        scale = self.scale()
+        params = {"weight": self.get_weight(orig_parameters["weight"]) * weight}
+        bias = self.get_bias(orig_parameters.get("bias", None))
+        if bias is not None:
+            params["bias"] = bias * (weight * scale)
+
+        # Reshape all params to match the original module's shape.
+        for param_name, param_weight in params.items():
+            orig_param = orig_parameters[param_name]
+            if param_weight.shape != orig_param.shape:
+                params[param_name] = param_weight.reshape(orig_param.shape)
+
+        return params
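
For reference, a minimal standalone sketch of the alpha scaling now applied to delta_v in get_weight(). The tensor shapes and the assumption that scale() equals alpha / rank are illustrative and not taken from this diff:

import torch

# Hypothetical illustration; `up`, `down`, `alpha`, and `rank` are made-up values,
# and scale = alpha / rank is an assumed definition of DoRALayer.scale().
rank, out_features, in_features = 4, 16, 32
up = torch.randn(out_features, rank)
down = torch.randn(rank, in_features)
alpha = 8.0
scale = alpha / rank

delta_v = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
delta_v = delta_v.reshape(out_features, in_features)
delta_v = delta_v * scale  # the scaling introduced by this commit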
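
And a self-contained sketch of the new compatibility guard in get_parameters(). check_patch_compatibility() is a hypothetical helper that mirrors only the check, since constructing a full DoRALayer instance is outside the scope of this diff:

import torch

# Hypothetical helper mirroring the meta-device check added in get_parameters().
def check_patch_compatibility(orig_parameters: dict[str, torch.Tensor]) -> None:
    if any(p.device.type == "meta" for p in orig_parameters.values()):
        raise RuntimeError(
            "The base model quantization format (likely bitsandbytes) is not supported for DoRA patches."
        )

# Per the comment in the diff, parameters on the 'meta' device are assumed to come from a
# quantization format (e.g. bitsandbytes) that can't be easily dequantized, so the patch is rejected:
try:
    check_patch_compatibility({"weight": torch.empty(16, 16, device="meta")})
except RuntimeError as err:
    print(f"Patch rejected: {err}")

# An ordinary, dequantizable module passes the check:
check_patch_compatibility({"weight": torch.randn(16, 16), "bias": torch.zeros(16)})
print("Patch accepted.")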