Compare commits

...

15 Commits

Author SHA1 Message Date
Ryan Dick
bc300bc498 Add a compute_dtype field to GGMLTensor. 2024-10-01 19:24:55 +00:00
Ryan Dick
c95a8d2a3c Add unit tests for GGMLTensor. 2024-10-01 18:38:38 +00:00
Ryan Dick
b31ba61d23 Fix type errors in GGMLTensor. 2024-10-01 15:40:35 +00:00
Brandon Rising
579eb8718e Remove no longer used code paths, general cleanup of new dequantization code, update probe 2024-09-30 23:15:19 -04:00
Brandon Rising
b4c5210902 Run ruff and update imports 2024-09-30 23:15:19 -04:00
Brandon Rising
1f58b26e74 Run ruff and fix typing in torch patcher 2024-09-30 23:15:19 -04:00
Brandon Rising
2eba5457da Various updates to gguf performance 2024-09-30 23:15:19 -04:00
Brandon
b1ac6f986e Update invokeai/backend/model_manager/load/model_loaders/flux.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-09-30 23:15:19 -04:00
Brandon Rising
7213fceaa6 Add gguf as a pyproject dependency 2024-09-30 23:15:19 -04:00
Ryan Dick
faae144b35 Get alternative GGUF implementation working... barely. 2024-09-30 23:15:19 -04:00
Ryan Dick
17bf03ab7f Initial experimentation with Tensor-like extension for GGUF. 2024-09-30 23:15:19 -04:00
Lincoln Stein
cc24a0e39f recognize .gguf files when scanning a folder for import 2024-09-30 23:15:19 -04:00
Brandon Rising
1a40b486e4 Run Ruff 2024-09-30 23:15:19 -04:00
Brandon Rising
12f5247caa Add unit tests for torch patcher 2024-09-30 23:15:17 -04:00
Brandon Rising
d43dd97016 Initial GGUF support for flux models 2024-09-30 22:24:11 -04:00
11 changed files with 695 additions and 7 deletions

View File

@@ -213,7 +213,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
cached_weights=cached_weights,
)
)
elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
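
The comment above is the key design note: because the weights are quantized, the LoRA update cannot simply be folded into the weight matrix. A rough, illustrative sketch of the two approaches (the class and function names here are invented for illustration and are not the InvokeAI patcher API):

import torch

class LoRASidecarLinear(torch.nn.Module):
    # Sidecar application: the (possibly quantized) base layer is left untouched and the
    # low-rank update is added to its output at run time. Slower, but format-agnostic.
    def __init__(self, base: torch.nn.Module, down: torch.Tensor, up: torch.Tensor, scale: float = 1.0):
        super().__init__()
        self.base = base    # e.g. a Linear whose weight is a GGMLTensor
        self.down = down    # (rank, in_features)
        self.up = up        # (out_features, rank)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * (x @ self.down.t() @ self.up.t())

def patch_linear_(layer: torch.nn.Linear, down: torch.Tensor, up: torch.Tensor, scale: float = 1.0) -> None:
    # Direct patching: fold the update into a plain dense weight once, ahead of inference.
    # This is not practical for quantized weights without dequantizing them first.
    layer.weight.data += scale * (up @ down)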

View File

@@ -114,6 +114,7 @@ class ModelFormat(str, Enum):
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
class SchedulerPredictionType(str, Enum):
@@ -197,7 +198,7 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
@@ -363,6 +364,21 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.GGUFQuantized
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
@@ -466,6 +482,7 @@ AnyModelConfig = Annotated[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
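
The AnyModelConfig union above is a pydantic tagged union: each config class contributes a Tag via get_tag(), and a discriminator maps incoming data to the right tag, so a main model probed as gguf_quantized deserializes into MainGGUFCheckpointConfig. A minimal, self-contained sketch of that pattern with pydantic v2 (the two config classes and the discriminator below are simplified stand-ins, not the real InvokeAI ones):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class CheckpointCfg(BaseModel):
    type: Literal["main"] = "main"
    format: Literal["checkpoint"] = "checkpoint"

class GGUFCfg(BaseModel):
    type: Literal["main"] = "main"
    format: Literal["gguf_quantized"] = "gguf_quantized"

def tag_for(value) -> str:
    # Simplified: build the "<type>.<format>" tag, mirroring Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}").
    data = value if isinstance(value, dict) else value.__dict__
    return f"{data['type']}.{data['format']}"

AnyCfg = Annotated[
    Union[
        Annotated[CheckpointCfg, Tag("main.checkpoint")],
        Annotated[GGUFCfg, Tag("main.gguf_quantized")],
    ],
    Discriminator(tag_for),
]

cfg = TypeAdapter(AnyCfg).validate_python({"type": "main", "format": "gguf_quantized"})
assert isinstance(cfg, GGUFCfg)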

View File

@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
@@ -35,6 +36,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
from invokeai.backend.model_manager.util.model_util import (
convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
@@ -204,6 +206,42 @@ class FluxCheckpointModel(ModelLoader):
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class FluxGGUFCheckpointModel(ModelLoader):
"""Class to load GGUF main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainGGUFCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
# Load the state dict and patcher
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
model = Flux(params[config.config_path])
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""

View File

@@ -30,6 +30,8 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -187,6 +189,7 @@ class ModelProbe(object):
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
@@ -220,7 +223,7 @@ class ModelProbe(object):
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
if model_path.name == "learned_embeds.bin":
@@ -408,6 +411,8 @@ class ModelProbe(object):
model = torch.load(model_path, map_location="cpu")
assert isinstance(model, dict)
return model
elif model_path.suffix.endswith(".gguf"):
return gguf_sd_loader(model_path, compute_dtype=torch.float32)
else:
return safetensors.torch.load_file(model_path)
@@ -477,6 +482,8 @@ class CheckpointProbeBase(ProbeBase):
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
):
return ModelFormat.BnbQuantizednf4b
elif any(isinstance(v, GGMLTensor) for v in state_dict.values()):
return ModelFormat.GGUFQuantized
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
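
A condensed sketch of the quantization-format detection that get_format now performs on a loaded state dict; the function name is illustrative, and the earlier bitsandbytes conditions are abbreviated to the key visible above:

from typing import Any

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

def sketch_get_format(state_dict: dict[str, Any]) -> str:
    # Illustrative reduction of CheckpointProbeBase.get_format's quantization checks.
    if "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
        return "bnb_quantized_nf4b"  # ModelFormat.BnbQuantizednf4b
    if any(isinstance(v, GGMLTensor) for v in state_dict.values()):
        return "gguf_quantized"      # ModelFormat.GGUFQuantized
    return "checkpoint"              # ModelFormat.Checkpoint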

View File

@@ -130,7 +130,7 @@ class ModelSearch:
return
for n in file_names:
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt", ".gguf")):
try:
self.model_found(absolute_path / n)
except KeyboardInterrupt:

View File

@@ -8,6 +8,8 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
checkpoint = {}
@@ -54,7 +56,10 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
if str(path).endswith(".gguf"):
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
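
For the non-GGUF branch, map_location=torch.device("meta") loads only tensor metadata (shape and dtype) without allocating storage, which is all the metadata reader needs; .gguf files take the gguf_sd_loader branch because torch.load does not understand the GGUF container format. A quick illustration of what a meta-device tensor provides:

import torch

t = torch.empty(2, 3, device="meta")
print(t.shape, t.dtype)   # torch.Size([2, 3]) torch.float32
print(t.device)           # meta
# Reading values (e.g. t.tolist()) would fail: a meta tensor has no underlying data.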

View File

@@ -0,0 +1,152 @@
from typing import overload
import gguf
import torch
from invokeai.backend.quantization.gguf.utils import (
DEQUANTIZE_FUNCTIONS,
TORCH_COMPATIBLE_QTYPES,
dequantize,
)
def dequantize_and_run(func, args, kwargs):
"""A helper function for running math ops on GGMLTensor inputs.
Dequantizes the inputs, and runs the function.
"""
dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
dequantized_kwargs = {
k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
}
return func(*dequantized_args, **dequantized_kwargs)
def apply_to_quantized_tensor(func, args, kwargs):
"""A helper function to apply a function to a quantized GGML tensor, and re-wrap the result in a GGMLTensor.
Assumes that the first argument is a GGMLTensor.
"""
# We expect the first argument to be a GGMLTensor, and all other arguments to be non-GGMLTensors.
ggml_tensor = args[0]
assert isinstance(ggml_tensor, GGMLTensor)
assert all(not isinstance(a, GGMLTensor) for a in args[1:])
assert all(not isinstance(v, GGMLTensor) for v in kwargs.values())
new_data = func(ggml_tensor.quantized_data, *args[1:], **kwargs)
if new_data.dtype != ggml_tensor.quantized_data.dtype:
# This is intended to catch calls such as `.to(dtype=torch.float32)`, which are not supported on GGMLTensors.
raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
return GGMLTensor(
new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
)
GGML_TENSOR_OP_TABLE = {
# Ops to run on the quantized tensor.
torch.ops.aten.detach.default: apply_to_quantized_tensor, # pyright: ignore
torch.ops.aten._to_copy.default: apply_to_quantized_tensor, # pyright: ignore
# Ops to run on dequantized tensors.
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
}
class GGMLTensor(torch.Tensor):
"""A torch.Tensor sub-class holding a quantized GGML tensor.
The underlying tensor is quantized, but the GGMLTensor class provides a dequantized view of the tensor on-the-fly
when it is used in operations.
"""
@staticmethod
def __new__(
cls,
data: torch.Tensor,
ggml_quantization_type: gguf.GGMLQuantizationType,
tensor_shape: torch.Size,
compute_dtype: torch.dtype,
):
# Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
return torch.Tensor._make_wrapper_subclass( # pyright: ignore
cls,
data.shape,
dtype=data.dtype,
layout=data.layout,
device=data.device,
strides=data.stride(),
storage_offset=data.storage_offset(),
)
def __init__(
self,
data: torch.Tensor,
ggml_quantization_type: gguf.GGMLQuantizationType,
tensor_shape: torch.Size,
compute_dtype: torch.dtype,
):
self.quantized_data = data
self._ggml_quantization_type = ggml_quantization_type
# The dequantized shape of the tensor.
self._tensor_shape = tensor_shape
self.compute_dtype = compute_dtype
def __repr__(self, *, tensor_contents=None):
return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
@overload
def size(self, dim: None = None) -> torch.Size: ...
@overload
def size(self, dim: int) -> int: ...
def size(self, dim: int | None = None):
"""Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
if dim is not None:
return self._tensor_shape[dim]
return self._tensor_shape
@property
def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
"""The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
return self.size()
@property
def quantized_shape(self) -> torch.Size:
"""The shape of the quantized tensor."""
return self.quantized_data.shape
def requires_grad_(self, mode: bool = True) -> torch.Tensor:
"""The GGMLTensor class is currently only designed for inference (not training). Setting requires_grad to True
is not supported. This method is a no-op.
"""
return self
def get_dequantized_tensor(self):
"""Return the dequantized tensor.
Args:
dtype: The dtype of the dequantized tensor.
"""
if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
return self.quantized_data.to(self.compute_dtype)
elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
# TODO(ryand): Look into how the dtype param is intended to be used.
return dequantize(
data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
).to(self.compute_dtype)
else:
# There is no GPU implementation for this quantization type, so fall back to the numpy implementation.
new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), self._ggml_quantization_type)
return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# We will likely hit cases here in the future where a new op is encountered that is not yet supported.
# The new op simply needs to be added to the GGML_TENSOR_OP_TABLE.
if func in GGML_TENSOR_OP_TABLE:
return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
return NotImplemented
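
A minimal usage sketch of the dispatch mechanism above, using the F16 qtype (one of TORCH_COMPATIBLE_QTYPES, so dequantization is just a cast to compute_dtype); it assumes the gguf package is installed:

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

data = torch.randn(4, 8, dtype=torch.float16)
w = GGMLTensor(
    data,
    ggml_quantization_type=gguf.GGMLQuantizationType.F16,
    tensor_shape=data.shape,
    compute_dtype=torch.float32,
)

x = torch.randn(4, 8)
# aten.mul.Tensor is in GGML_TENSOR_OP_TABLE, so this dequantizes w and runs the multiply.
y = w * x
assert y.dtype == torch.float32
# An op missing from the table returns NotImplemented from __torch_dispatch__, so the call
# fails loudly rather than silently operating on the packed quantized data.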

View File

@@ -0,0 +1,22 @@
from pathlib import Path
import gguf
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
reader = gguf.GGUFReader(path)
sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(
torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
)
return sd
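
A hedged usage sketch of the loader: every tensor in the file comes back as a GGMLTensor keyed by its GGUF name, and only F32/F16 entries are reshaped eagerly since the quantized ones keep their packed layout (the file path below is hypothetical):

from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("/models/flux1-dev-Q8_0.gguf"), compute_dtype=torch.bfloat16)  # hypothetical path
for name, tensor in list(sd.items())[:5]:
    # .shape reports the dequantized shape; .quantized_shape reports the packed storage shape.
    print(name, tensor._ggml_quantization_type.name, tuple(tensor.shape), tuple(tensor.quantized_shape))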

View File

@@ -0,0 +1,327 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from typing import Callable, Optional, Union
import gguf
import torch
TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
# K Quants #
QK_K = 256
K_SCALE_SIZE = 12
def get_scale_min(scales: torch.Tensor):
n_blocks = scales.shape[0]
scales = scales.view(torch.uint8)
scales = scales.reshape((n_blocks, 3, 4))
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
# Legacy Quants #
def dequantize_blocks_Q8_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
x = x.view(torch.int8)
return d * x
def dequantize_blocks_Q5_1(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape((n_blocks, -1))
qs = ql | (qh << 4)
return (d * qs) + m
def dequantize_blocks_Q5_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, qh, qs = split_block_dims(blocks, 2, 4)
d = d.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape(n_blocks, -1)
qs = (ql | (qh << 4)).to(torch.int8) - 16
return d * qs
def dequantize_blocks_Q4_1(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, m, qs = split_block_dims(blocks, 2, 2)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qs = (qs & 0x0F).reshape(n_blocks, -1)
return (d * qs) + m
def dequantize_blocks_Q4_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, qs = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
return d * qs
def dequantize_blocks_BF16(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
def dequantize_blocks_Q6_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
(
ql,
qh,
scales,
d,
) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
scales = scales.view(torch.int8).to(dtype)
d = d.view(torch.float16).to(dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = (qh & 0x03).reshape((n_blocks, -1, 32))
q = (ql | (qh << 4)).to(torch.int8) - 32
q = q.reshape((n_blocks, QK_K // 16, -1))
return (d * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q5_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = (qh & 0x01).reshape((n_blocks, -1, 32))
q = ql | (qh << 4)
return (d * q - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q4_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q3_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
d = d.view(torch.float16).to(dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 2, 1)
)
lscales = lscales.reshape((n_blocks, 16))
hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
[0, 2, 4, 6], device=d.device, dtype=torch.uint8
).reshape((1, 4, 1))
hscales = hscales.reshape((n_blocks, 16))
scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
scales = scales.to(torch.int8) - 32
dl = (d * scales).reshape((n_blocks, 16, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
return (dl * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q2_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
qs = qs.reshape((n_blocks, QK_K // 16, 16))
qs = dl * qs - ml
return qs.reshape((n_blocks, -1))
DEQUANTIZE_FUNCTIONS: dict[
gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
def is_torch_compatible(tensor: Optional[torch.Tensor]):
return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
def is_quantized(tensor: torch.Tensor):
return not is_torch_compatible(tensor)
def dequantize_tensor(
tensor: torch.Tensor, dtype: torch.dtype, dequant_dtype: Union[torch.dtype, str, None] = None
) -> torch.Tensor:
qtype: Optional[gguf.GGMLQuantizationType] = getattr(tensor, "tensor_type", None)
oshape: torch.Size = getattr(tensor, "tensor_shape", tensor.shape)
if qtype is None:
raise ValueError("This is not a valid quantized tensor")
if qtype in TORCH_COMPATIBLE_QTYPES:
return tensor.to(dtype)
elif qtype in DEQUANTIZE_FUNCTIONS:
dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
if not (dequant_dtype is None or isinstance(dequant_dtype, torch.dtype)):
raise ValueError("dequant_dtype must be a torch.dtype")
return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
else:
new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
return torch.from_numpy(new).to(tensor.device, dtype=dtype)
def dequantize(
data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
"""
Dequantize tensor back to usable shape/dtype
"""
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // type_size
blocks = rows.reshape((n_blocks, type_size))
blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
return blocks.reshape(oshape)
def to_uint32(x: torch.Tensor) -> torch.Tensor:
x = x.view(torch.uint8).to(torch.int32)
return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
def split_block_dims(blocks: torch.Tensor, *args):
n_max = blocks.shape[1]
dims = list(args) + [n_max - sum(args)]
return torch.split(blocks, dims, dim=1)
PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
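
To make the block format concrete, a small worked example for Q8_0: each 34-byte block holds one float16 scale d followed by 32 int8 quants q, and dequantization is simply d * q. The block here is hand-packed for illustration and checked against the dequantize helper added above:

import gguf
import torch

from invokeai.backend.quantization.gguf.utils import dequantize

block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0]  # (32, 34)
scale = torch.tensor([0.5], dtype=torch.float16).view(torch.uint8)             # 2 bytes: the f16 scale d
quants = torch.arange(-16, 16, dtype=torch.int8).view(torch.uint8)             # 32 bytes: the int8 quants q
raw = torch.cat([scale, quants])                                                # one 34-byte block

out = dequantize(raw, gguf.GGMLQuantizationType.Q8_0, oshape=torch.Size([32]), dtype=torch.float32)
assert torch.allclose(out, 0.5 * torch.arange(-16, 16, dtype=torch.float32))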

View File

@@ -39,6 +39,7 @@ dependencies = [
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.27.2",
"gguf==0.10.0",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
@@ -51,10 +52,10 @@ dependencies = [
"sentencepiece==0.2.0",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2",
"torch==2.4.1",
"torchmetrics==0.11.4",
"torchsde==0.2.6",
"torchvision==0.17.2",
"torchvision==0.19.1",
"transformers==4.41.1",
# Core application dependencies, pinned for reproducible builds.
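
For completeness, a quick environment check against the new pins (gguf added, torch/torchvision bumped) might look like:

from importlib.metadata import version

for pkg, pinned in {"gguf": "0.10.0", "torch": "2.4.1", "torchvision": "0.19.1"}.items():
    print(f"{pkg}: installed {version(pkg)}, pinned {pinned}")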

View File

@@ -0,0 +1,115 @@
import gguf
import pytest
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
def quantize_tensor(data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType) -> GGMLTensor:
"""Quantize a torch.Tensor to a GGMLTensor.
Uses the gguf library's numpy implementation to quantize the tensor.
"""
data_np = data.detach().cpu().numpy()
quantized_np = gguf.quantize(data_np, ggml_quantization_type)
return GGMLTensor(
data=torch.from_numpy(quantized_np),
ggml_quantization_type=ggml_quantization_type,
tensor_shape=data.shape,
compute_dtype=data.dtype,
).to(device=data.device) # type: ignore
@pytest.mark.parametrize(
["device", "x1_quant_type", "x2_quant_type"],
[
# Test with no quantization.
("cpu", None, None),
# Test with Q8_0 quantization.
("cpu", gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q8_0),
("cpu", None, gguf.GGMLQuantizationType.Q8_0),
("cpu", gguf.GGMLQuantizationType.Q8_0, None),
# Test with F16 quantization (i.e. torch-compatible quantization).
("cpu", gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.F16),
("cpu", None, gguf.GGMLQuantizationType.F16),
("cpu", gguf.GGMLQuantizationType.F16, None),
# Test all of above cases on CUDA.
("cuda", None, None),
# Test with Q8_0 quantization.
("cuda", gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q8_0),
("cuda", None, gguf.GGMLQuantizationType.Q8_0),
("cuda", gguf.GGMLQuantizationType.Q8_0, None),
# Test with F16 quantization (i.e. torch-compatible quantization).
("cuda", gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.F16),
("cuda", None, gguf.GGMLQuantizationType.F16),
("cuda", gguf.GGMLQuantizationType.F16, None),
],
)
def test_ggml_tensor_multiply(
device: str, x1_quant_type: gguf.GGMLQuantizationType | None, x2_quant_type: gguf.GGMLQuantizationType | None
):
# Skip test if CUDA is not available.
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA is not available.")
generator = torch.Generator().manual_seed(123)
x1 = torch.randn(32, 64, generator=generator).to(device=device)
x2 = torch.randn(32, 64, generator=generator).to(device=device)
# Quantize the tensors.
x1_quantized = quantize_tensor(x1, x1_quant_type) if x1_quant_type is not None else x1
x2_quantized = quantize_tensor(x2, x2_quant_type) if x2_quant_type is not None else x2
# Check devices.
for x in [x1, x2, x1_quantized, x2_quantized]:
assert x.device.type == device
# Perform the multiplication.
result = x1 * x2
result_quantized = x1_quantized * x2_quantized
assert result.shape == result_quantized.shape
assert result.dtype == result_quantized.dtype
assert torch.allclose(result, result_quantized, atol=1e-1)
def test_ggml_tensor_to_dtype_raises_error():
x = torch.randn(32, 64)
x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)
with pytest.raises(ValueError):
x_quantized.to(dtype=torch.float32)
with pytest.raises(ValueError):
x_quantized.float()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
def test_ggml_tensor_to_device():
x = torch.randn(32, 64)
x_cpu = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)
x_gpu = x_cpu.to(device=torch.device("cuda"))
assert x_cpu.device.type == "cpu"
assert x_gpu.device.type == "cuda"
assert torch.allclose(x_cpu.quantized_data, x_gpu.quantized_data.cpu(), atol=1e-5)
def test_ggml_tensor_shape():
x = torch.randn(32, 64)
x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)
assert x_quantized.shape == x.shape
assert x_quantized.size() == x.size()
def test_ggml_tensor_quantized_shape():
x = torch.randn(32, 64)
x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)
# This is mainly just a smoke test to confirm that .quantized_shape can be accessed and doesn't hit any weird
# dispatch errors.
assert x_quantized.quantized_shape != x.shape