Update PartialLayer to work with unquantized / GGML quantized / BnB quantized layers.

Ryan Dick
2025-01-24 15:52:57 +00:00
parent 92c6a7d658
commit 67afa7e339
2 changed files with 19 additions and 1 deletion
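The core rule this commit adds, extracted as a self-contained sketch before the raw diff: unquantized parameters (and GGML tensors, via the op-table change in the second file) can be sliced directly, while other quantized formats are swapped for shape-only placeholders on the 'meta' device. OpaqueQuantParam and prepare_for_slicing are hypothetical names for illustration, and plain .shape stands in for the repo's get_param_shape() helper; the real logic lives in PartialLayer.get_parameters() below.

import torch


class OpaqueQuantParam(torch.Tensor):
    """Hypothetical stand-in for a quantized format that cannot be sliced (e.g. BnB)."""


def prepare_for_slicing(param: torch.Tensor) -> torch.Tensor:
    # Unquantized parameters are sliceable as-is.
    if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
        return param
    # (The real code also passes GGMLTensor through, since its op table handles
    # aten.slice.Tensor -- see the second file below.)
    # Everything else: keep only the shape, on the storage-less 'meta' device.
    return torch.empty(param.shape, device="meta")


plain = torch.nn.Parameter(torch.randn(320, 640))
print(prepare_for_slicing(plain)[0:160].shape)  # torch.Size([160, 640])

opaque = torch.randn(320, 640).as_subclass(OpaqueQuantParam)
print(prepare_for_slicing(opaque).device)  # meta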


@@ -4,6 +4,8 @@ import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
+from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 @dataclass
@@ -28,6 +30,18 @@ class PartialLayer(BaseLayerPatch):
         self.range = range
 
     def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
+        # HACK(ryand): If the original parameters are in a quantized format that can't be sliced, we replace them with
+        # dummy tensors on the 'meta' device. This allows sub-layers to access the shapes of the sliced parameters. But,
+        # of course, any sub-layers that need to access the actual values of the parameters will fail.
+        for param_name in orig_parameters.keys():
+            param = orig_parameters[param_name]
+            if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
+                pass
+            elif type(param) is GGMLTensor:
+                pass
+            else:
+                orig_parameters[param_name] = torch.empty(get_param_shape(param), device="meta")
+
         # Slice the original parameters to the specified range.
         sliced_parameters: dict[str, torch.Tensor] = {}
         for param_name, param_weight in orig_parameters.items():
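To make the HACK above concrete, a small demonstration of what a 'meta' placeholder buys you: shape and dtype metadata propagate through view/slice ops, but there is no storage behind them, so any attempt to read values fails. This is stock PyTorch behavior, independent of this commit.

import torch

placeholder = torch.empty(320, 640, device="meta")

# View/slice ops are fine on meta tensors -- only metadata is propagated.
print(placeholder[:160, :].shape)  # torch.Size([160, 640])
print(torch.zeros_like(placeholder).device)  # meta

# Reading actual values is impossible: there is no storage to copy out of.
try:
    placeholder.cpu()
except (NotImplementedError, RuntimeError):  # exact type varies by PyTorch version
    print("meta tensors have no data to copy")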
@@ -47,7 +61,9 @@
         out_params: dict[str, torch.Tensor] = {}
         for param_name, param_weight in params.items():
             orig_param = orig_parameters[param_name]
-            out_params[param_name] = torch.zeros_like(orig_param)
+            out_params[param_name] = torch.zeros(
+                get_param_shape(orig_param), dtype=param_weight.dtype, device=param_weight.device
+            )
 
             if param_name == "weight":
                 out_params[param_name][

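Why the zeros_like -> torch.zeros change in the last hunk matters, sketched under the assumption that orig_param may now be a meta placeholder from the HACK above: zeros_like would inherit the meta device (and the original dtype), producing an unusable output buffer, whereas allocating explicitly with the patch weight's dtype and device yields a real tensor of the original's logical shape. Plain .shape stands in for get_param_shape() here.

import torch

orig_param = torch.empty(320, 640, device="meta")  # placeholder from the HACK above
patch_weight = torch.randn(16, 640, dtype=torch.float16)

bad = torch.zeros_like(orig_param)
print(bad.device)  # meta -- cannot accumulate real values into this

good = torch.zeros(orig_param.shape, dtype=patch_weight.dtype, device=patch_weight.device)
print(good.device, good.dtype)  # cpu torch.float16 -- a usable output buffer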
invokeai/backend/quantization/gguf/ggml_tensor.py

@@ -54,7 +54,9 @@ GGML_TENSOR_OP_TABLE = {
     torch.ops.aten.addmm.default: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.mul.Tensor: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.add.Tensor: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.sub.Tensor: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.allclose.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.slice.Tensor: dequantize_and_run,  # pyright: ignore
 }
 
 if torch.backends.mps.is_available():
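For context, a minimal sketch of the dispatch pattern behind GGML_TENSOR_OP_TABLE: a wrapper tensor subclass intercepts aten ops in __torch_dispatch__, and table entries like aten.slice.Tensor are serviced by dequantizing and re-running the op on plain tensors, which is what lets PartialLayer slice GGML weights. ToyQuantTensor and its scale-based 'quantization' are stand-ins, not InvokeAI's implementation.

import torch


def dequantize_and_run(func, args, kwargs):
    # Replace every ToyQuantTensor argument with its dequantized payload, then run
    # the underlying aten op on plain tensors.
    unpacked = [a.dequantize() if isinstance(a, ToyQuantTensor) else a for a in args]
    return func(*unpacked, **(kwargs or {}))


class ToyQuantTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        # Wrapper subclass: carries shape/dtype metadata but no storage of its own.
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data: torch.Tensor, scale: float):
        self.quantized_data = data
        self.scale = scale

    def dequantize(self) -> torch.Tensor:
        return self.quantized_data * self.scale

    OP_TABLE = {
        torch.ops.aten.slice.Tensor: dequantize_and_run,
        torch.ops.aten.add.Tensor: dequantize_and_run,
    }

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func in cls.OP_TABLE:
            return cls.OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"{func} is not supported on ToyQuantTensor")


t = ToyQuantTensor(torch.ones(8, 4), scale=0.5)
print(t[0:2].shape)  # aten.slice.Tensor -> dequantize_and_run -> torch.Size([2, 4])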