Remove no-longer-used code paths, general cleanup of new dequantization code, update probe

This commit is contained in:
Brandon Rising
2024-09-30 22:50:15 -04:00
committed by Kent Keirsey
parent 7d9f125232
commit 446e2884bc
6 changed files with 4 additions and 477 deletions

View File

@@ -30,7 +30,7 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -482,7 +482,7 @@ class CheckpointProbeBase(ProbeBase):
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
):
return ModelFormat.BnbQuantizednf4b
elif any(isinstance(v, GGUFTensor) for v in state_dict.values()):
elif any(isinstance(v, GGMLTensor) for v in state_dict.values()):
return ModelFormat.GGUFQuantized
return ModelFormat("checkpoint")

View File

@@ -1,174 +0,0 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from typing import Callable, List, Optional, Union
import gguf
import torch
from invokeai.backend.quantization.gguf.utils import dequantize_tensor, is_quantized
PATCH_TYPES = Union[list[torch.Tensor], tuple[torch.Tensor]]
class GGUFTensor(torch.Tensor):
"""
Main tensor-like class for storing quantized weights.
Inherits from torch.Tensor and adds additional attributes.
"""
tensor_type: Union[torch.dtype, gguf.GGMLQuantizationType, None]
tensor_shape: torch.Size
patches: List[Callable[[torch.Tensor], torch.Tensor]]
def __new__(
cls,
data,
tensor_type: Union[torch.dtype, gguf.GGMLQuantizationType],
tensor_shape: torch.Size,
patches: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
**kwargs,
):
# Create a new tensor instance using the superclass method
if isinstance(data, torch.Tensor):
tensor = data.as_subclass(cls)
else:
tensor = torch.tensor(data, **kwargs).as_subclass(cls)
# Set the additional attributes
tensor.tensor_type = tensor_type
tensor.tensor_shape = tensor_shape
tensor.patches = patches or []
return tensor
def __init__(
self,
data,
tensor_type: Union[torch.dtype, gguf.GGMLQuantizationType],
tensor_shape: torch.Size,
patches: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
**kwargs,
):
# __init__ is not called for torch.Tensor subclasses
pass
def to(self, *args, **kwargs):
# Create a new tensor with the desired type/device and copy attributes
new = super().to(*args, **kwargs)
new = new.as_subclass(GGUFTensor)
new.tensor_type = getattr(self, "tensor_type", self.dtype)
new.tensor_shape = getattr(self, "tensor_shape", self.size())
new.patches = getattr(self, "patches", []).copy()
return new
def clone(self, *args, **kwargs):
return self
def detach(self, *args, **kwargs):
return self
def copy_(self, *args, **kwargs):
# Attempt to copy data into the tensor; handle exceptions gracefully
try:
new = super().copy_(*args, **kwargs)
new = new.as_subclass(GGUFTensor)
new.tensor_type = getattr(self, "tensor_type", self.dtype)
new.tensor_shape = getattr(self, "tensor_shape", self.size())
new.patches = getattr(self, "patches", []).copy()
return new
except Exception as e:
print(f"Ignoring 'copy_' on tensor: {e}")
def __deepcopy__(self, memo):
# Create a deep copy of the tensor and copy attributes
new = super().__deepcopy__(memo)
if isinstance(new, torch.Tensor):
new = new.as_subclass(GGUFTensor)
new.tensor_type = getattr(self, "tensor_type", self.dtype)
new.tensor_shape = getattr(self, "tensor_shape", self.size())
new.patches = getattr(self, "patches", []).copy()
return new
@property
def shape(self):
if not hasattr(self, "tensor_shape"):
self.tensor_shape = self.size()
return self.tensor_shape
class GGUFLayer(torch.nn.Module):
"""
This class should be responsible for de-quantizing on the fly
"""
dequant_dtype = None
patch_dtype = None
torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
def is_ggml_quantized(self, *, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None):
weight = weight if weight is not None else self.weight
bias = bias if bias is not None else self.bias
weight_quantized = is_quantized(weight)
bias_quantized = is_quantized(bias)
return weight_quantized or bias_quantized
def _load_from_state_dict(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs):
weight, bias = state_dict.get(f"{prefix}weight", None), state_dict.get(f"{prefix}bias", None)
if self.is_ggml_quantized(weight=weight, bias=bias):
return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def ggml_load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys: list[str],
unexpected_keys,
error_msgs,
):
for k, v in state_dict.items():
if k.endswith("weight"):
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k.endswith("bias") and v is not None:
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
missing_keys.append(k)
def get_weight(self, tensor: Optional[torch.Tensor], dtype: torch.dtype):
if tensor is None:
return
# dequantize tensor while patches load
weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
return weight
def calc_size(self) -> int:
"""Get the size of this model in bytes."""
return self.bias.nelement() * self.bias.element_size()
def cast_bias_weight(
self,
input: torch.Tensor,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
bias_dtype: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if dtype is None:
dtype = getattr(input, "dtype", torch.float32)
if dtype is None:
raise ValueError("dtype is required")
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = self.get_weight(self.bias.to(device), dtype)
if bias is not None:
bias = bias.to(dtype=bias_dtype, device=device, copy=False)
weight = self.get_weight(self.weight.to(device), dtype)
if weight is not None:
weight = weight.to(dtype=dtype, device=device)
if weight is None or bias is None:
raise ValueError("Weight or bias is None")
return weight, bias
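
(Aside: the core trick in the removed GGUFTensor.__new__ above is torch.Tensor.as_subclass, which re-labels an existing tensor under a subclass without copying its storage. A minimal standalone illustration follows; the TaggedTensor name is hypothetical and not part of the codebase.)

import torch

class TaggedTensor(torch.Tensor):
    """Hypothetical subclass, used only to illustrate as_subclass re-labelling."""
    tag: str

data = torch.arange(6, dtype=torch.float32)
tagged = data.as_subclass(TaggedTensor)  # no copy: shares storage with `data`
tagged.tag = "example"
assert isinstance(tagged, TaggedTensor)
assert tagged.data_ptr() == data.data_ptr()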

View File

@@ -1,19 +1,16 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from pathlib import Path
import gguf
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
def gguf_sd_loader(path: Path) -> dict[str, GGMLTensor]:
reader = gguf.GGUFReader(path)
sd: dict[str, GGUFTensor] = {}
sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
@@ -21,62 +18,3 @@ def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
return sd
# def gguf_sd_loader(
# path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
# ) -> dict[str, GGUFTensor]:
# """
# Read state dict as fake tensors
# """
# reader = gguf.GGUFReader(path)
# prefix_len = len(handle_prefix)
# tensor_names = {tensor.name for tensor in reader.tensors}
# has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
# tensors: list[tuple[str, gguf.ReaderTensor]] = []
# for tensor in reader.tensors:
# sd_key = tensor_name = tensor.name
# if has_prefix:
# if not tensor_name.startswith(handle_prefix):
# continue
# sd_key = tensor_name[prefix_len:]
# tensors.append((sd_key, tensor))
# # detect and verify architecture
# compat = None
# arch_str = None
# arch_field = reader.get_field("general.architecture")
# if arch_field is not None:
# if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
# raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
# arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
# if arch_str not in {"flux"}:
# raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
# else:
# arch_str = detect_arch({val[0] for val in tensors})
# compat = "sd.cpp"
# # main loading loop
# state_dict: dict[str, GGUFTensor] = {}
# qtype_dict: dict[str, int] = {}
# for sd_key, tensor in tensors:
# tensor_name = tensor.name
# tensor_type_str = str(tensor.tensor_type)
# torch_tensor = torch.from_numpy(tensor.data) # mmap
# shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
# # Workaround for stable-diffusion.cpp SDXL detection.
# if compat == "sd.cpp" and arch_str == "sdxl":
# if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
# while len(shape) > 2 and shape[-1] == 1:
# shape = shape[:-1]
# # add to state dict
# if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
# torch_tensor = torch_tensor.view(*shape)
# state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
# qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
# return state_dict
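
(Aside: a short sketch of inspecting a GGUF file with the same gguf.GGUFReader API used above, tallying quantization types much like the commented-out loader's qtype_dict did; the path is a placeholder.)

from collections import Counter
from pathlib import Path

import gguf

reader = gguf.GGUFReader(Path("models/flux1-dev-Q4_K_S.gguf"))  # placeholder path
# Count how many tensors use each GGML quantization type (e.g. F16 vs. quantized blocks).
qtype_counts = Counter(str(t.tensor_type) for t in reader.tensors)
print(qtype_counts)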

View File

@@ -1,96 +0,0 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Type
import wrapt
from torch import Tensor, bfloat16, dtype, float16, nn
from invokeai.backend.quantization.gguf.layers import GGUFLayer
class TorchPatcher:
@classmethod
@contextmanager
def wrap(cls) -> Generator[None, None, None]:
# Dictionary to store original torch.nn classes for later restoration
original_classes: Dict[str, Type[Any]] = {}
try:
# Iterate over cls's attributes and replace matching torch.nn classes
for attr_name in dir(cls):
if attr_name.startswith("__"):
continue
# Get the class from cls
patcher_class: Type[Any] = getattr(cls, attr_name)
# Check if torch.nn has a class with the same name
if hasattr(nn, attr_name):
# Get the original torch.nn class
original_class: Type[Any] = getattr(nn, attr_name)
# Save the original class for restoration later
original_classes[attr_name] = original_class
# Apply the patch
patched_class = cls.create_patch_function(patcher_class)(original_class)
setattr(nn, attr_name, patched_class)
yield
finally:
# Restore the original torch.nn classes
for attr_name, original_class in original_classes.items():
setattr(nn, attr_name, original_class)
@staticmethod
def create_patch_function(patcher_attr: Type[Any]) -> Callable[[Type[Any]], Type[Any]]:
# Return a new patch_class function specific to this patcher_attr
@wrapt.decorator
def patch_class(
wrapped: Callable[..., Any],
instance: Any,
args: Any,
kwargs: Any,
) -> Any:
# Call the patcher_attr version of the class
return patcher_attr(*args, **kwargs)
return patch_class
class GGUFPatcher(TorchPatcher):
"""
Dequantize weights on the fly before doing the compute
"""
class Linear(GGUFLayer, nn.Linear):
def forward(self, input: Tensor) -> Tensor:
weight, bias = self.cast_bias_weight(input)
return nn.functional.linear(input, weight, bias)
class Conv2d(GGUFLayer, nn.Conv2d):
def forward(self, input: Tensor) -> Tensor:
weight, bias = self.cast_bias_weight(input)
return self._conv_forward(input, weight, bias)
class Embedding(GGUFLayer, nn.Embedding):
def forward(self, input: Tensor, out_dtype: Optional[dtype] = None) -> Tensor:
output_dtype = out_dtype
if not self.weight:
raise ValueError("Embedding layer must have a weight")
if self.weight.dtype == float16 or self.weight.dtype == bfloat16:
out_dtype = None
weight, _ = self.cast_bias_weight(input, device=input.device, dtype=out_dtype)
return nn.functional.embedding(
input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
).to(dtype=output_dtype)
class LayerNorm(GGUFLayer, nn.LayerNorm):
def forward(self, input: Tensor) -> Tensor:
if self.weight is None:
return nn.functional.layer_norm(input, self.normalized_shape, None, None, self.eps)
weight, bias = self.cast_bias_weight(input)
return nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
class GroupNorm(GGUFLayer, nn.GroupNorm):
def forward(self, input: Tensor) -> Tensor:
weight, bias = self.cast_bias_weight(input)
return nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
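
(Aside: before this commit, the patcher was used as a context manager so that modules constructed inside it picked up the dequantizing layer subclasses. A minimal sketch of that now-removed usage, mirroring the deleted tests below; these imports no longer exist after this change.)

from torch import nn

from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher

with GGUFPatcher.wrap():
    # Inside the context, nn.Linear resolves to the GGUFLayer-backed subclass.
    layer = nn.Linear(3072, 18432)
    assert isinstance(layer, GGUFPatcher.Linear)

# On exit, the original torch.nn classes are restored.
assert nn.Linear is not GGUFPatcher.Linear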

View File

@@ -1,17 +0,0 @@
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
def test_ggml_tensor():
"""Smoke test that multiplication works on a GGMLTensor."""
weight: GGUFTensor = torch.load("tests/assets/gguf_qweight.pt")
tensor_shape = weight.tensor_shape
tensor_type = weight.tensor_type
data = torch.Tensor(weight.data)
ggml_tensor = GGMLTensor(data, tensor_type, tensor_shape)
ones = torch.ones([1], dtype=torch.float32)
_ = ggml_tensor * ones

View File

@@ -1,124 +0,0 @@
import pytest
import torch
import torch.nn as nn
from invokeai.backend.quantization.gguf.layers import GGUFLayer
from invokeai.backend.quantization.gguf.torch_patcher import TorchPatcher
quantized_sd = {
"linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
"linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}
class TestGGUFPatcher(TorchPatcher):
class Linear(GGUFLayer, nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight, bias = self.cast_bias_weight(input)
return nn.functional.linear(input, weight, bias)
class Test2GGUFPatcher(TorchPatcher):
class Linear(GGUFLayer, nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight, bias = self.cast_bias_weight(input)
return nn.functional.linear(input, weight, bias)
# Define a dummy module for testing
class DummyModule(nn.Module):
def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
super().__init__()
self.linear = nn.Linear(3072, 18432, device=device, dtype=dtype)
def forward(self, x):
x = self.linear(x)
return x
# Test that TorchPatcher patches and unpatches nn.Linear correctly
def test_torch_patcher_patches_nn_linear():
original_linear = nn.Linear
with TorchPatcher.wrap():
# nn.Linear should not be replaced
assert nn.Linear is original_linear
assert nn.Linear is original_linear
# Test that GGUFPatcher patches and unpatches nn.Linear correctly
def test_gguf_patcher_patches_nn_linear():
original_linear = nn.Linear
with TestGGUFPatcher.wrap():
# nn.Linear should be replaced
assert nn.Linear is not original_linear
# Create a linear layer and check its type
linear_layer = nn.Linear(3072, 18432)
assert isinstance(linear_layer, TestGGUFPatcher.Linear)
# nn.Linear should be restored
assert nn.Linear is original_linear
# Test that unpatching restores the original behavior
def test_gguf_patcher_unpatch_restores_behavior():
device = "cpu"
dtype = torch.float32
input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
model = DummyModule(device=device, dtype=dtype)
with pytest.raises(Exception): # noqa: B017
model.load_state_dict(quantized_sd)
original_linear = nn.Linear
with TestGGUFPatcher.wrap():
patched_model = DummyModule(device=device, dtype=dtype)
patched_model.load_state_dict(quantized_sd)
# Will raise if patch is not applied
patched_model(input_tensor)
assert isinstance(nn.Linear(4, 8), TestGGUFPatcher.Linear)
# Ensure nn.Linear is restored
assert nn.Linear is not TestGGUFPatcher.Linear
assert isinstance(nn.Linear(4, 8), original_linear)
# Test that the patched Linear layer behaves as expected
def test_gguf_patcher_linear_layer_behavior():
device = "cpu"
dtype = torch.float32
input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
model = DummyModule(device=device, dtype=dtype)
with pytest.raises(Exception): # noqa: B017
model.load_state_dict(quantized_sd)
with TestGGUFPatcher.wrap():
patched_model = DummyModule(device=device, dtype=dtype)
patched_model.load_state_dict(quantized_sd)
patched_tensor = patched_model(input_tensor)
# After unpatching, run forward and ensure patched classes are still applied
assert torch.equal(patched_tensor, patched_model(input_tensor))
# Test that the TorchPatcher works correctly when nesting contexts
def test_torch_patcher_nested_contexts():
original_linear = nn.Linear
with TestGGUFPatcher.wrap():
# First level patching
first_level_linear = nn.Linear
assert first_level_linear is not original_linear
with Test2GGUFPatcher.wrap():
# Second level patching
second_level_linear = nn.Linear
assert second_level_linear is not first_level_linear
# After exiting inner context, nn.Linear should be restored to first level patch
assert nn.Linear is first_level_linear
# After exiting outer context, nn.Linear should be restored to original
assert nn.Linear is original_linear