Move quantized weight handling for patch layers up from ConcatenatedLoRALayer to CustomModuleMixin.

Author: Ryan Dick
Date: 2025-01-24 17:21:16 +00:00
parent 28514ba59a
commit 5d472ac1b8
2 changed files with 14 additions and 13 deletions


@@ -3,6 +3,8 @@ import copy
 import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 class CustomModuleMixin:
@@ -42,6 +44,18 @@ class CustomModuleMixin:
         device: torch.device | None = None,
     ):
         """Helper function that aggregates the parameters from all patches into a single dict."""
+        # HACK(ryand): If the original parameters are in a quantized format whose weights can't be accessed, we replace
+        # them with dummy tensors on the 'meta' device. This allows patch layers to access the shapes of the original
+        # parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
+        for param_name in orig_params.keys():
+            param = orig_params[param_name]
+            if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
+                pass
+            elif type(param) is GGMLTensor:
+                pass
+            else:
+                orig_params[param_name] = torch.empty(get_param_shape(param), device="meta")
+
         params: dict[str, torch.Tensor] = {}
         for patch, patch_weight in patches_and_weights:

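Aside: the HACK block above relies on two PyTorch behaviors: tensors on the 'meta' device carry shape and dtype but no storage, and torch.empty(..., device="meta") allocates nothing. Below is a minimal standalone sketch of the trick, where FakeQuantizedParam and fake_get_param_shape are hypothetical stand-ins for a quantized tensor type and invokeai's get_param_shape helper:

import torch

class FakeQuantizedParam:
    """Hypothetical stand-in for a quantized parameter whose raw weights can't be read."""
    def __init__(self, logical_shape: tuple[int, ...]):
        self.logical_shape = logical_shape  # hypothetical attribute

def fake_get_param_shape(param) -> tuple[int, ...]:
    # Hypothetical stand-in for get_param_shape(): report the logical
    # (dequantized) shape regardless of how the bytes are stored.
    if isinstance(param, torch.Tensor):
        return tuple(param.shape)
    return param.logical_shape

orig_params = {
    "weight": FakeQuantizedParam((320, 768)),
    "bias": torch.nn.Parameter(torch.zeros(320)),
}

# Replace unreadable params with shape-only dummies on the 'meta' device.
for param_name in orig_params.keys():
    param = orig_params[param_name]
    if not isinstance(param, torch.Tensor):
        orig_params[param_name] = torch.empty(fake_get_param_shape(param), device="meta")

print(orig_params["weight"].shape)   # torch.Size([320, 768]) -- shape is visible
# Reading actual values fails, because a meta tensor has no data, e.g.:
# orig_params["weight"].cpu()  # raises NotImplementedError: Cannot copy out of meta tensor

Patch layers that only consult shapes keep working against quantized originals; anything that touches values will raise, exactly as the comment warns.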

@@ -5,7 +5,6 @@ import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 @dataclass
@@ -35,18 +34,6 @@ class ConcatenatedLoRALayer(BaseLayerPatch):
         assert len(self.ranges) == len(self.lora_layers)
 
     def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
-        # HACK(ryand): If the original parameters are in a quantized format that can't be sliced, we replace them with
-        # dummy tensors on the 'meta' device. This allows sub-layers to access the shapes of the sliced parameters. But,
-        # of course, any sub-layers that need to access the actual values of the parameters will fail.
-        for param_name in orig_parameters.keys():
-            param = orig_parameters[param_name]
-            if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
-                pass
-            elif type(param) is GGMLTensor:
-                pass
-            else:
-                orig_parameters[param_name] = torch.empty(get_param_shape(param), device="meta")
-
         out_parameters: dict[str, torch.Tensor] = {}
         for lora_layer, range in zip(self.lora_layers, self.ranges, strict=True):
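With the substitution handled once in CustomModuleMixin, ConcatenatedLoRALayer.get_parameters can assume any original weight it receives at least exposes a usable shape. A rough sketch of why shape-only access suffices for a concatenated-LoRA-style patch; concat_lora_get_parameters and the flat per-layer deltas are illustrative assumptions, not the actual implementation:

import torch

def concat_lora_get_parameters(
    orig_weight: torch.Tensor,  # may be a shape-only dummy on the 'meta' device
    lora_deltas: list[torch.Tensor],
    weight: float,
) -> dict[str, torch.Tensor]:
    # Illustrative sketch: each sub-layer's delta covers a contiguous slice of
    # the output dimension, so only orig_weight's shape is consulted -- never
    # its values. This is why a meta-device dummy is sufficient here.
    out_features = orig_weight.shape[0]
    assert sum(d.shape[0] for d in lora_deltas) == out_features
    return {"weight": weight * torch.cat(lora_deltas, dim=0)}

orig_weight = torch.empty((6, 4), device="meta")  # quantized original, values inaccessible
deltas = [torch.ones(2, 4), 2 * torch.ones(4, 4)]
params = concat_lora_get_parameters(orig_weight, deltas, weight=0.5)
print(params["weight"].shape)  # torch.Size([6, 4])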