diff --git a/invokeai/backend/patches/layers/partial_layer.py b/invokeai/backend/patches/layers/partial_layer.py index 9d77d92f41..1a3135e58a 100644 --- a/invokeai/backend/patches/layers/partial_layer.py +++ b/invokeai/backend/patches/layers/partial_layer.py @@ -4,6 +4,8 @@ import torch from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase +from invokeai.backend.patches.layers.param_shape_utils import get_param_shape +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor @dataclass @@ -28,6 +30,18 @@ class PartialLayer(BaseLayerPatch): self.range = range def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: + # HACK(ryand): If the original parameters are in a quantized format that can't be sliced, we replace them with + # dummy tensors on the 'meta' device. This allows sub-layers to access the shapes of the sliced parameters. But, + # of course, any sub-layers that need to access the actual values of the parameters will fail. + for param_name in orig_parameters.keys(): + param = orig_parameters[param_name] + if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor: + pass + elif type(param) is GGMLTensor: + pass + else: + orig_parameters[param_name] = torch.empty(get_param_shape(param), device="meta") + # Slice the original parameters to the specified range. sliced_parameters: dict[str, torch.Tensor] = {} for param_name, param_weight in orig_parameters.items(): @@ -47,7 +61,9 @@ class PartialLayer(BaseLayerPatch): out_params: dict[str, torch.Tensor] = {} for param_name, param_weight in params.items(): orig_param = orig_parameters[param_name] - out_params[param_name] = torch.zeros_like(orig_param) + out_params[param_name] = torch.zeros( + get_param_shape(orig_param), dtype=param_weight.dtype, device=param_weight.device + ) if param_name == "weight": out_params[param_name][ diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index 62be2bdb63..d48948dcfa 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -54,7 +54,9 @@ GGML_TENSOR_OP_TABLE = { torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore + torch.ops.aten.sub.Tensor: dequantize_and_run, # pyright: ignore torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore + torch.ops.aten.slice.Tensor: dequantize_and_run, # pyright: ignore } if torch.backends.mps.is_available():