mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 07:28:06 -05:00)

Compare commits: v5.1.1 ... ryan/gguf- (6 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 68d14c68c6 |  |
|  | 776f274ed5 |  |
|  | 01b088712d |  |
|  | b1999e2b52 |  |
|  | 0a1a0cac28 |  |
|  | 72a63d94fb |  |
@@ -114,6 +114,7 @@ class ModelFormat(str, Enum):
     T5Encoder = "t5_encoder"
     BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
     BnbQuantizednf4b = "bnb_quantized_nf4b"
+    GGUFQuantized = "gguf_quantized"
 
 
 class SchedulerPredictionType(str, Enum):

@@ -196,7 +197,7 @@ class ModelConfigBase(BaseModel):
 class CheckpointConfigBase(ModelConfigBase):
     """Model config for checkpoint-style models."""
 
-    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
+    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
         description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
     )
     config_path: str = Field(description="path to the checkpoint model config file")

@@ -362,6 +363,21 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
 
 
+class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
+    """Model config for main checkpoint models."""
+
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.format = ModelFormat.GGUFQuantized
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
+
+
 class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
     """Model config for main diffusers models."""
 

@@ -465,6 +481,7 @@ AnyModelConfig = Annotated[
         Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
         Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
         Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
+        Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
         Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
         Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
         Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
     CLIPEmbedDiffusersConfig,
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
+    MainGGUFCheckpointConfig,
     T5EncoderBnbQuantizedLlmInt8bConfig,
     T5EncoderConfig,
     VAECheckpointConfig,

@@ -35,6 +36,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
 from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
 )
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:

@@ -204,6 +206,50 @@ class FluxCheckpointModel(ModelLoader):
             return model
 
 
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
+class FluxGGUFCheckpointModel(ModelLoader):
+    """Class to load GGUF main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CheckpointConfigBase):
+            raise ValueError("Only CheckpointConfigBase models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config)
+
+        raise ValueError(
+            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+        )
+
+    def _load_from_singlefile(
+        self,
+        config: AnyModelConfig,
+    ) -> AnyModel:
+        assert isinstance(config, MainGGUFCheckpointConfig)
+        model_path = Path(config.path)
+
+        with SilenceWarnings():
+            # Load the state dict and patcher
+            sd = gguf_sd_loader(model_path)
+            # Initialize the model
+            model = Flux(params[config.config_path])
+
+            # Calculate new state dictionary size and make room in the cache
+            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+            self._ram_cache.make_room(new_sd_size)
+
+            # Load the state dict into the model
+            model.load_state_dict(sd, assign=True)
+
+            # Return the model after patching
+            return model
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
 class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     """Class to load main models."""
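Note (editorial, not part of the diff): the load path above hinges on model.load_state_dict(sd, assign=True). With assign=True, the module adopts the tensors from the state dict themselves instead of copying their values into its pre-existing dense parameters, so the GGMLTensor objects produced by gguf_sd_loader end up as the module's weights. A minimal sketch of that assign=True behavior with plain tensors:

import torch
from torch import nn

m = nn.Linear(4, 2)  # parameters start out as float32
sd = {"weight": torch.zeros(2, 4, dtype=torch.float16), "bias": torch.zeros(2, dtype=torch.float16)}

# assign=True replaces the parameters with the state-dict tensors, preserving their dtype.
m.load_state_dict(sd, assign=True)
assert m.weight.dtype == torch.float16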
@@ -26,6 +26,8 @@ from invokeai.backend.model_manager.config import (
     SchedulerPredictionType,
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.quantization.gguf.layers import GGUFTensor
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -183,6 +185,7 @@ class ModelProbe(object):
         if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
             ModelFormat.Checkpoint,
             ModelFormat.BnbQuantizednf4b,
+            ModelFormat.GGUFQuantized,
         ]:
             ckpt_config_path = cls._get_checkpoint_config_path(
                 model_path,

@@ -216,7 +219,7 @@ class ModelProbe(object):
 
     @classmethod
     def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
-        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
+        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
             raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
 
         if model_path.name == "learned_embeds.bin":

@@ -402,6 +405,8 @@ class ModelProbe(object):
             model = torch.load(model_path, map_location="cpu")
             assert isinstance(model, dict)
             return model
+        elif model_path.suffix.endswith(".gguf"):
+            return gguf_sd_loader(model_path)
         else:
             return safetensors.torch.load_file(model_path)

@@ -471,6 +476,8 @@ class CheckpointProbeBase(ProbeBase):
             or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
         ):
             return ModelFormat.BnbQuantizednf4b
+        elif any(isinstance(v, GGUFTensor) for v in state_dict.values()):
+            return ModelFormat.GGUFQuantized
         return ModelFormat("checkpoint")
 
     def get_variant_type(self) -> ModelVariantType:

@@ -130,7 +130,7 @@ class ModelSearch:
             return
 
         for n in file_names:
-            if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
+            if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt", ".gguf")):
                 try:
                     self.model_found(absolute_path / n)
                 except KeyboardInterrupt:
@@ -8,6 +8,8 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path
 
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
+
 
 def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
     checkpoint = {}

@@ -54,7 +56,10 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
         scan_result = scan_file_path(path)
         if scan_result.infected_files != 0:
             raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
-    checkpoint = torch.load(path, map_location=torch.device("meta"))
+    if str(path).endswith(".gguf"):
+        checkpoint = gguf_sd_loader(Path(path))
+    else:
+        checkpoint = torch.load(path, map_location=torch.device("meta"))
     return checkpoint
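A small usage sketch for the updated helper (editorial, not part of the diff; the path is hypothetical): .gguf checkpoints are routed through gguf_sd_loader, while every other supported suffix still goes through torch.load on the meta device.

from pathlib import Path

from invokeai.backend.model_manager.util.model_util import read_checkpoint_meta

meta_sd = read_checkpoint_meta(Path("/models/flux1-dev-Q4_K_S.gguf"))  # hypothetical GGUF checkpoint
print(len(meta_sd))  # number of tensors found in the file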
invokeai/backend/quantization/gguf/ggml_tensor.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import gguf
import torch

from invokeai.backend.quantization.gguf.utils import (
    DEQUANTIZE_FUNCTIONS,
    TORCH_COMPATIBLE_QTYPES,
    dequantize,
)

def dequantize_and_run(func, args, kwargs):
    # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.bfloat16.
    dequantized_args = [
        a.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(a, "get_dequantized_tensor") else a for a in args
    ]
    dequantized_kwargs = {
        k: v.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(v, "get_dequantized_tensor") else v
        for k, v in kwargs.items()
    }
    return func(*dequantized_args, **dequantized_kwargs)

def apply_to_quantized_tensor(func, args, kwargs):
    ggml_tensor = args[0]
    assert isinstance(ggml_tensor, GGMLTensor)
    new_data = func(ggml_tensor._data, *args[1:], **kwargs)
    return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)


GGML_TENSOR_OP_TABLE = {
    torch.ops.aten.detach.default: apply_to_quantized_tensor,
    torch.ops.aten._to_copy.default: apply_to_quantized_tensor,
    # --
    torch.ops.aten.t.default: dequantize_and_run,
    torch.ops.aten.addmm.default: dequantize_and_run,
    torch.ops.aten.mul.Tensor: dequantize_and_run,
}


class GGMLTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )

    def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
        self._data = data
        self._ggml_quantization_type = ggml_quantization_type
        # The dequantized shape of the tensor.
        self._tensor_shape = tensor_shape

    def __repr__(self):
        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape={self._tensor_shape})"

    def size(self):
        return self._tensor_shape

    @property
    def shape(self):
        return self.size()

    def requires_grad_(self, requires_grad: bool = True):
        # TODO(ryand): Think about whether we should set requires_grad on the underlying tensor.
        return self

    def get_dequantized_tensor(self, dtype: torch.dtype):
        """Return the dequantized tensor.

        Args:
            dtype: The dtype of the dequantized tensor.
        """
        if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
            return self._data.to(dtype)
        elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
            # TODO(ryand): Look into how the dtype param is intended to be used.
            return dequantize(
                data=self._data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
            ).to(dtype)
        else:
            # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
            new = gguf.quants.dequantize(self._data.cpu().numpy(), self._ggml_quantization_type)
            return torch.from_numpy(new).to(self._data.device, dtype=dtype)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in GGML_TENSOR_OP_TABLE:
            return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Unsupported function {func}")
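A minimal usage sketch of the wrapper above (editorial, not part of the diff): ops listed in GGML_TENSOR_OP_TABLE are intercepted by __torch_dispatch__, and anything routed through dequantize_and_run sees a dequantized tensor. The data below is synthetic; F16 is one of the TORCH_COMPATIBLE_QTYPES, so "dequantizing" it is just a dtype cast.

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

data = torch.randn(4, 8, dtype=torch.float16)  # stand-in for raw tensor data read from a GGUF file
ggml = GGMLTensor(data, gguf.GGMLQuantizationType.F16, data.shape)

# torch.ops.aten.mul.Tensor is in GGML_TENSOR_OP_TABLE, so this dequantizes the GGMLTensor operand and multiplies.
out = ggml * torch.ones(8)
print(out.shape)  # torch.Size([4, 8]); the result is a plain torch.Tensor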
invokeai/backend/quantization/gguf/layers.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Optional, Union

import gguf
from torch import Tensor, device, dtype, float32, nn, zeros_like

from invokeai.backend.quantization.gguf.utils import dequantize_tensor, is_quantized

PATCH_TYPES = Union[list[Tensor], tuple[Tensor]]


class GGUFTensor(Tensor):
    """
    Main tensor-like class for storing quantized weights
    """

    def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        self.patches = patches or []

    def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            print(f"ignoring 'copy_' on tensor: {e}")

    def __deepcopy__(self, *args, **kwargs):
        # Intel Arc fix, ref#50
        new = super().__deepcopy__(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape


class GGUFLayer(nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """

    dequant_dtype = None
    patch_dtype = None
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None):
        if weight is None or bias is None:
            return False
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight", None), state_dict.get(f"{prefix}bias", None)
        if self.is_ggml_quantized(weight=weight, bias=bias):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(
        self,
        state_dict: dict[str, Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys: list[str],
        unexpected_keys,
        error_msgs,
    ):
        for k, v in state_dict.items():
            if k.endswith("weight"):
                self.weight = nn.Parameter(v, requires_grad=False)
            elif k.endswith("bias") and v is not None:
                self.bias = nn.Parameter(v, requires_grad=False)
            else:
                missing_keys.append(k)

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination: dict[str, Tensor], prefix: str):
        # This is a fake state dict for vram estimation
        weight = zeros_like(self.weight, device=device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = zeros_like(self.bias, device=device("meta"))
            destination[prefix + "bias"] = bias
        return

    def get_weight(self, tensor: Optional[Tensor], dtype: dtype):
        if tensor is None:
            return

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
        return weight

    def calc_size(self) -> int:
        """Get the size of this model in bytes."""
        return self.bias.nelement() * self.bias.element_size()

    def cast_bias_weight(
        self,
        input: Tensor,
        dtype: Optional[dtype] = None,
        device: Optional[device] = None,
        bias_dtype: Optional[dtype] = None,
    ) -> tuple[Tensor, Tensor]:
        if dtype is None:
            dtype = getattr(input, "dtype", float32)
            if dtype is None:
                raise ValueError("dtype is required")
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

        bias = self.get_weight(self.bias.to(device), dtype)
        if bias is not None:
            bias = bias.to(dtype=bias_dtype, device=device, copy=False)

        weight = self.get_weight(self.weight.to(device), dtype)
        if weight is not None:
            weight = weight.to(dtype=dtype, device=device)
        if weight is None or bias is None:
            raise ValueError("Weight or bias is None")
        return weight, bias
invokeai/backend/quantization/gguf/loaders.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
    reader = gguf.GGUFReader(path)

    sd: dict[str, GGUFTensor] = {}
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
    return sd


# def gguf_sd_loader(
#     path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
# ) -> dict[str, GGUFTensor]:
#     """
#     Read state dict as fake tensors
#     """
#     reader = gguf.GGUFReader(path)

#     prefix_len = len(handle_prefix)
#     tensor_names = {tensor.name for tensor in reader.tensors}
#     has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

#     tensors: list[tuple[str, gguf.ReaderTensor]] = []
#     for tensor in reader.tensors:
#         sd_key = tensor_name = tensor.name
#         if has_prefix:
#             if not tensor_name.startswith(handle_prefix):
#                 continue
#             sd_key = tensor_name[prefix_len:]
#         tensors.append((sd_key, tensor))

#     # detect and verify architecture
#     compat = None
#     arch_str = None
#     arch_field = reader.get_field("general.architecture")
#     if arch_field is not None:
#         if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
#             raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
#         arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
#         if arch_str not in {"flux"}:
#             raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
#     else:
#         arch_str = detect_arch({val[0] for val in tensors})
#         compat = "sd.cpp"

#     # main loading loop
#     state_dict: dict[str, GGUFTensor] = {}
#     qtype_dict: dict[str, int] = {}
#     for sd_key, tensor in tensors:
#         tensor_name = tensor.name
#         tensor_type_str = str(tensor.tensor_type)
#         torch_tensor = torch.from_numpy(tensor.data)  # mmap

#         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
#         # Workaround for stable-diffusion.cpp SDXL detection.
#         if compat == "sd.cpp" and arch_str == "sdxl":
#             if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
#                 while len(shape) > 2 and shape[-1] == 1:
#                     shape = shape[:-1]

#         # add to state dict
#         if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
#             torch_tensor = torch_tensor.view(*shape)
#         state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
#         qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

#     return state_dict
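A usage sketch for the active loader above (editorial, not part of the diff; the path is hypothetical): the GGUF reader memory-maps the file, and each entry comes back as a GGMLTensor whose .shape reports the dequantized shape.

from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("/models/flux1-dev-Q4_K_S.gguf"))  # hypothetical GGUF checkpoint
for name, tensor in list(sd.items())[:3]:
    print(name, tuple(tensor.shape))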
invokeai/backend/quantization/gguf/torch_patcher.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from contextlib import contextmanager
from typing import Generator, Optional

import wrapt
from torch import Tensor, bfloat16, dtype, float16, nn

from invokeai.backend.quantization.gguf.layers import GGUFLayer


class TorchPatcher:
    @classmethod
    @contextmanager
    def wrap(cls) -> Generator[None, None, None]:
        # Dictionary to store original torch.nn classes for later restoration
        original_classes = {}
        try:
            # Iterate over _patcher's attributes and replace matching torch.nn classes
            for attr_name in dir(cls):
                if attr_name.startswith("__"):
                    continue
                # Get the class from _patcher
                patcher_class = getattr(cls, attr_name)

                # Check if torch.nn has a class with the same name
                if hasattr(nn, attr_name):
                    # Get the original torch.nn class
                    original_class = getattr(nn, attr_name)

                    # Define a helper function to bind the current patcher_attr for each iteration
                    def create_patch_function(patcher_attr):
                        # Return a new patch_class function specific to this patcher_attr
                        @wrapt.decorator
                        def patch_class(wrapped, instance, args, kwargs):
                            # Call the _patcher version of the class
                            return patcher_attr(*args, **kwargs)

                        return patch_class

                    # Save the original class for restoration later
                    original_classes[attr_name] = original_class

                    # Apply the patch
                    setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))
            yield
        finally:
            # Restore the original torch.nn classes
            for attr_name, original_class in original_classes.items():
                setattr(nn, attr_name, original_class)

class GGUFPatcher(TorchPatcher):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)

    class Conv2d(GGUFLayer, nn.Conv2d):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGUFLayer, nn.Embedding):
        def forward(self, input: Tensor, out_dtype: Optional[dtype] = None) -> Tensor:
            output_dtype = out_dtype
            if self.weight is None:
                raise ValueError("Embedding layer must have a weight")
            if self.weight.dtype == float16 or self.weight.dtype == bfloat16:
                out_dtype = None
            weight, _ = self.cast_bias_weight(input, device=input.device, dtype=out_dtype)
            return nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGUFLayer, nn.LayerNorm):
        def forward(self, input: Tensor) -> Tensor:
            if self.weight is None:
                return nn.functional.layer_norm(input, self.normalized_shape, None, None, self.eps)
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGUFLayer, nn.GroupNorm):
        def forward(self, input: Tensor) -> Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
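A minimal usage sketch (editorial, not part of the diff): wrap() swaps matching torch.nn classes for the GGUF-aware subclasses for the duration of the context and restores the originals afterwards; the test suite below exercises the same behavior with real quantized weights.

from torch import nn

from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher

original_linear = nn.Linear
with GGUFPatcher.wrap():
    # While the patch is active, constructing nn.Linear yields the GGUF-aware subclass.
    layer = nn.Linear(8, 4)
    assert isinstance(layer, GGUFPatcher.Linear)
assert nn.Linear is original_linear  # the original class is restored on exit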
invokeai/backend/quantization/gguf/utils.py (new file, 361 lines)
@@ -0,0 +1,361 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Callable, Optional, Union

import gguf
import torch

TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

# K Quants #
QK_K = 256
K_SCALE_SIZE = 12

MODEL_DETECTION = (
    (
        "flux",
        (
            ("transformer_blocks.0.attn.norm_added_k.weight",),
            ("double_blocks.0.img_attn.proj.weight",),
        ),
    ),
)


def get_scale_min(scales: torch.Tensor):
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))


# Legacy Quants #
def dequantize_blocks_Q8_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x


def dequantize_blocks_Q5_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = ql | (qh << 4)
    return (d * qs) + m


def dequantize_blocks_Q5_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs


def dequantize_blocks_Q4_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return d * qs


def dequantize_blocks_BF16(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


def dequantize_blocks_Q6_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    (
        ql,
        qh,
        scales,
        d,
    ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 2, 1)
    )
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
        [0, 2, 4, 6], device=d.device, dtype=torch.uint8
    ).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = scales.to(torch.int8) - 32

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = ql.to(torch.int8) - (qh << 2).to(torch.int8)

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


DEQUANTIZE_FUNCTIONS: dict[
    gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


def is_torch_compatible(tensor: Optional[torch.Tensor]):
    return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor: torch.Tensor):
    return not is_torch_compatible(tensor)


def dequantize_tensor(
    tensor: torch.Tensor, dtype: torch.dtype, dequant_dtype: Union[torch.dtype, str, None] = None
) -> torch.Tensor:
    qtype: Optional[gguf.GGMLQuantizationType] = getattr(tensor, "tensor_type", None)
    oshape: torch.Size = getattr(tensor, "tensor_shape", tensor.shape)
    if qtype is None:
        raise ValueError("This is not a valid quantized tensor")
    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in DEQUANTIZE_FUNCTIONS:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(
    data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]

    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def to_uint32(x: torch.Tensor) -> torch.Tensor:
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks: torch.Tensor, *args):
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]


def move_patch_to_device(item: PATCH_TYPES, device: torch.device) -> PATCH_TYPES:
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        if len(item) == 0:
            return item
        if not isinstance(item[0], torch.Tensor):
            raise ValueError("Invalid item")
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        if len(item) == 0:
            return item
        if not isinstance(item[0], torch.Tensor):
            raise ValueError("Invalid item")
        return [move_patch_to_device(x, device) for x in item]


def detect_arch(state_dict: dict[str, torch.Tensor]):
    for arch, match_lists in MODEL_DETECTION:
        for match_list in match_lists:
            if all(key in state_dict for key in match_list):
                return arch
    raise ValueError("Unknown model architecture!")
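A worked example for the legacy Q8_0 path above (editorial, not part of the diff; the block bytes are synthetic): each Q8_0 block is type_size = 34 bytes, a float16 scale d followed by block_size = 32 signed 8-bit quants q, and it dequantizes to d * q.

import torch

from invokeai.backend.quantization.gguf.utils import dequantize_blocks_Q8_0

d = torch.tensor([0.5], dtype=torch.float16).view(torch.uint8)  # 2 bytes of float16 scale
q = torch.arange(-16, 16, dtype=torch.int8).view(torch.uint8)   # 32 int8 quants, reinterpreted as raw bytes
block = torch.cat([d, q]).reshape(1, 34)                        # (n_blocks, type_size)

out = dequantize_blocks_Q8_0(block, block_size=32, type_size=34, dtype=torch.float32)
assert out[0, 0].item() == 0.5 * -16  # each dequantized value is scale * quant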
@@ -51,10 +51,10 @@ dependencies = [
     "sentencepiece==0.2.0",
     "spandrel==0.3.4",
     "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-    "torch==2.2.2",
+    "torch==2.4.1",
     "torchmetrics==0.11.4",
     "torchsde==0.2.6",
-    "torchvision==0.17.2",
+    "torchvision==0.19.1",
     "transformers==4.41.1",
 
 # Core application dependencies, pinned for reproducible builds.
tests/assets/gguf_qbias.pt (new binary file, not shown)
tests/assets/gguf_qweight.pt (new binary file, not shown)

tests/backend/quantization/gguf/test_gguf_tensor.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor


def test_ggml_tensor():
    """Smoke test that multiplication works on a GGMLTensor."""
    weight: GGUFTensor = torch.load("tests/assets/gguf_qweight.pt")
    tensor_shape = weight.tensor_shape
    tensor_type = weight.tensor_type
    data = torch.Tensor(weight.data)

    ggml_tensor = GGMLTensor(data, tensor_type, tensor_shape)
    ones = torch.ones([1], dtype=torch.float32)

    x = ggml_tensor * ones
tests/backend/quantization/gguf/test_layers.py (new file, 122 lines)
@@ -0,0 +1,122 @@
import pytest
import torch
import torch.nn as nn

from invokeai.backend.quantization.gguf.layers import GGUFLayer
from invokeai.backend.quantization.gguf.torch_patcher import TorchPatcher

quantized_sd = {
    "linear.weight": torch.load("tests/assets/gguf_qweight.pt"),
    "linear.bias": torch.load("tests/assets/gguf_qbias.pt"),
}


class TestGGUFPatcher(TorchPatcher):
    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)


class Test2GGUFPatcher(TorchPatcher):
    class Linear(GGUFLayer, nn.Linear):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            weight, bias = self.cast_bias_weight(input)
            return nn.functional.linear(input, weight, bias)


# Define a dummy module for testing
class DummyModule(nn.Module):
    def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32):
        super().__init__()
        self.linear = nn.Linear(3072, 18432, device=device, dtype=dtype)

    def forward(self, x):
        x = self.linear(x)
        return x


# Test that TorchPatcher patches and unpatches nn.Linear correctly
def test_torch_patcher_patches_nn_linear():
    original_linear = nn.Linear

    with TorchPatcher.wrap():
        # nn.Linear should not be replaced
        assert nn.Linear is original_linear
    assert nn.Linear is original_linear


# Test that GGUFPatcher patches and unpatches nn.Linear correctly
def test_gguf_patcher_patches_nn_linear():
    original_linear = nn.Linear

    with TestGGUFPatcher.wrap():
        # nn.Linear should be replaced
        assert nn.Linear is not original_linear
        # Create a linear layer and check its type
        linear_layer = nn.Linear(3072, 18432)
        assert isinstance(linear_layer, TestGGUFPatcher.Linear)
    # nn.Linear should be restored
    assert nn.Linear is original_linear


# Test that unpatching restores the original behavior
def test_gguf_patcher_unpatch_restores_behavior():
    device = "cpu"
    dtype = torch.float32

    input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
    model = DummyModule(device=device, dtype=dtype)
    with pytest.raises(Exception):  # noqa: B017
        model.load_state_dict(quantized_sd)

    with TestGGUFPatcher.wrap():
        patched_model = DummyModule(device=device, dtype=dtype)
        patched_model.load_state_dict(quantized_sd)
        # Will raise if patch is not applied
        patched_model(input_tensor)

    # Ensure nn.Linear is restored
    assert nn.Linear is not TestGGUFPatcher.Linear
    assert isinstance(nn.Linear(4, 8), nn.Linear)


# Test that the patched Linear layer behaves as expected
def test_gguf_patcher_linear_layer_behavior():
    device = "cpu"
    dtype = torch.float32

    input_tensor = torch.randn(1, 3072, device=device, dtype=dtype)
    model = DummyModule(device=device, dtype=dtype)
    with pytest.raises(Exception):  # noqa: B017
        model.load_state_dict(quantized_sd)

    with TestGGUFPatcher.wrap():
        patched_model = DummyModule(device=device, dtype=dtype)
        patched_model.load_state_dict(quantized_sd)

        patched_tensor = patched_model(input_tensor)

    # After unpatching, run forward and ensure patched classes are still applied
    assert torch.equal(patched_tensor, patched_model(input_tensor))


# Test that the TorchPatcher works correctly when nesting contexts
def test_torch_patcher_nested_contexts():
    original_linear = nn.Linear

    with TestGGUFPatcher.wrap():
        # First level patching
        first_level_linear = nn.Linear
        assert first_level_linear is not original_linear

        with Test2GGUFPatcher.wrap():
            # Second level patching
            second_level_linear = nn.Linear
            assert second_level_linear is not first_level_linear

        # After exiting inner context, nn.Linear should be restored to first level patch
        assert nn.Linear is first_level_linear

    # After exiting outer context, nn.Linear should be restored to original
    assert nn.Linear is original_linear