mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 07:48:12 -05:00
Update PartialLayer to work with unquantized / GGML quantized / BnB quantized layers.
This commit is contained in:
@@ -4,6 +4,8 @@ import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.param_shape_utils import get_param_shape
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -28,6 +30,18 @@ class PartialLayer(BaseLayerPatch):
|
||||
self.range = range
|
||||
|
||||
def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
|
||||
# HACK(ryand): If the original parameters are in a quantized format that can't be sliced, we replace them with
|
||||
# dummy tensors on the 'meta' device. This allows sub-layers to access the shapes of the sliced parameters. But,
|
||||
# of course, any sub-layers that need to access the actual values of the parameters will fail.
|
||||
for param_name in orig_parameters.keys():
|
||||
param = orig_parameters[param_name]
|
||||
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
|
||||
pass
|
||||
elif type(param) is GGMLTensor:
|
||||
pass
|
||||
else:
|
||||
orig_parameters[param_name] = torch.empty(get_param_shape(param), device="meta")
|
||||
|
||||
# Slice the original parameters to the specified range.
|
||||
sliced_parameters: dict[str, torch.Tensor] = {}
|
||||
for param_name, param_weight in orig_parameters.items():
|
||||
@@ -47,7 +61,9 @@ class PartialLayer(BaseLayerPatch):
|
||||
out_params: dict[str, torch.Tensor] = {}
|
||||
for param_name, param_weight in params.items():
|
||||
orig_param = orig_parameters[param_name]
|
||||
out_params[param_name] = torch.zeros_like(orig_param)
|
||||
out_params[param_name] = torch.zeros(
|
||||
get_param_shape(orig_param), dtype=param_weight.dtype, device=param_weight.device
|
||||
)
|
||||
|
||||
if param_name == "weight":
|
||||
out_params[param_name][
|
||||
|
||||
@@ -54,7 +54,9 @@ GGML_TENSOR_OP_TABLE = {
|
||||
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.sub.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.slice.Tensor: dequantize_and_run, # pyright: ignore
|
||||
}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
|
||||
Reference in New Issue
Block a user