Compare commits: main...ryan/gguf-

15 Commits

| SHA1 |
|---|
| bc300bc498 |
| c95a8d2a3c |
| b31ba61d23 |
| 579eb8718e |
| b4c5210902 |
| 1f58b26e74 |
| 2eba5457da |
| b1ac6f986e |
| 7213fceaa6 |
| faae144b35 |
| 17bf03ab7f |
| cc24a0e39f |
| 1a40b486e4 |
| 12f5247caa |
| d43dd97016 |
@@ -213,7 +213,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                         cached_weights=cached_weights,
                     )
                 )
-            elif config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]:
+            elif config.format in [
+                ModelFormat.BnbQuantizedLlmInt8b,
+                ModelFormat.BnbQuantizednf4b,
+                ModelFormat.GGUFQuantized,
+            ]:
                 # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
                 # than directly patching the weights, but is agnostic to the quantization format.
                 exit_stack.enter_context(
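The comment in this hunk carries the design rationale: quantized formats (bitsandbytes and now GGUF) get LoRA applied as sidecar layers instead of weight patching. As a rough illustration, here is a minimal sketch, assuming a plain torch.nn.Linear base layer and hypothetical down/up matrices (this is not InvokeAI's actual patcher classes), of why the sidecar approach is agnostic to the quantization format at the cost of an extra matmul per layer:

import torch


class LoRASidecarLinear(torch.nn.Module):
    """Hypothetical sidecar wrapper: the (possibly quantized) base layer is left untouched;
    the low-rank LoRA delta is computed in full precision and added to the base output."""

    def __init__(self, base: torch.nn.Module, down: torch.Tensor, up: torch.Tensor, scale: float = 1.0):
        super().__init__()
        self.base = base    # e.g. a linear layer whose weight is quantized
        self.down = down    # (rank, in_features)
        self.up = up        # (out_features, rank)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base(x) runs on the quantized weight as-is; only the LoRA branch is dense.
        return self.base(x) + (x @ self.down.T @ self.up.T) * self.scale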
@@ -114,6 +114,7 @@ class ModelFormat(str, Enum):
     T5Encoder = "t5_encoder"
     BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
     BnbQuantizednf4b = "bnb_quantized_nf4b"
+    GGUFQuantized = "gguf_quantized"


 class SchedulerPredictionType(str, Enum):
@@ -197,7 +198,7 @@ class ModelConfigBase(BaseModel):
 class CheckpointConfigBase(ModelConfigBase):
     """Model config for checkpoint-style models."""

-    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
+    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
         description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
     )
     config_path: str = Field(description="path to the checkpoint model config file")
@@ -363,6 +364,21 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")


+class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
+    """Model config for main checkpoint models."""
+
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.format = ModelFormat.GGUFQuantized
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
+
+
 class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
     """Model config for main diffusers models."""
@@ -466,6 +482,7 @@ AnyModelConfig = Annotated[
     Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
     Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
     Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
+    Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
     Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
     Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
     Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
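The Tag returned by get_tag() feeds the pydantic discriminated union in the AnyModelConfig block above. As a rough, self-contained sketch, assuming simplified stand-in config classes and a hypothetical "<type>.<format>" discriminator function (not InvokeAI's real classes), this is how a pydantic v2 callable discriminator plus Annotated/Tag entries selects the right config class:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CheckpointConfig(BaseModel):
    type: Literal["main"] = "main"
    format: Literal["checkpoint"] = "checkpoint"
    path: str


class GGUFCheckpointConfig(BaseModel):
    type: Literal["main"] = "main"
    format: Literal["gguf_quantized"] = "gguf_quantized"
    path: str


def get_config_tag(value) -> str:
    # Mirror the "<type>.<format>" tag convention used by get_tag() above.
    if isinstance(value, dict):
        return f"{value.get('type')}.{value.get('format')}"
    return f"{value.type}.{value.format}"


AnyConfig = Annotated[
    Union[
        Annotated[CheckpointConfig, Tag("main.checkpoint")],
        Annotated[GGUFCheckpointConfig, Tag("main.gguf_quantized")],
    ],
    Discriminator(get_config_tag),
]

adapter = TypeAdapter(AnyConfig)
cfg = adapter.validate_python({"type": "main", "format": "gguf_quantized", "path": "model.gguf"})
assert isinstance(cfg, GGUFCheckpointConfig)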
@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
     CLIPEmbedDiffusersConfig,
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
+    MainGGUFCheckpointConfig,
     T5EncoderBnbQuantizedLlmInt8bConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
@@ -35,6 +36,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
 from invokeai.backend.model_manager.util.model_util import (
     convert_bundle_to_flux_transformer_checkpoint,
 )
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 try:
@@ -204,6 +206,42 @@ class FluxCheckpointModel(ModelLoader):
         return model


+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
+class FluxGGUFCheckpointModel(ModelLoader):
+    """Class to load GGUF main models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, CheckpointConfigBase):
+            raise ValueError("Only CheckpointConfigBase models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config)
+
+        raise ValueError(
+            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+        )
+
+    def _load_from_singlefile(
+        self,
+        config: AnyModelConfig,
+    ) -> AnyModel:
+        assert isinstance(config, MainGGUFCheckpointConfig)
+        model_path = Path(config.path)
+
+        with SilenceWarnings():
+            # Load the state dict and patcher
+            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
+            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
+            model = Flux(params[config.config_path])
+            model.load_state_dict(sd, assign=True)
+        return model
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
 class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     """Class to load main models."""
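The load path above hinges on load_state_dict(sd, assign=True). A hedged, toy-scale sketch (hypothetical TinyProj module, random weights, not the real Flux transformer) of why assign=True matters: it attaches the GGMLTensor objects from the state dict to the module directly instead of copying their values into the existing dense parameters, so the quantized representation and its on-the-fly dequantization survive loading.

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class TinyProj(torch.nn.Module):
    """Toy stand-in for the Flux transformer; only the loading mechanics are of interest."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 32, bias=False)


weight = torch.randn(32, 64)
quantized = torch.from_numpy(gguf.quantize(weight.numpy(), gguf.GGMLQuantizationType.Q8_0))
sd = {
    "proj.weight": GGMLTensor(
        quantized,
        ggml_quantization_type=gguf.GGMLQuantizationType.Q8_0,
        tensor_shape=weight.shape,
        compute_dtype=torch.bfloat16,
    )
}

model = TinyProj()
# assign=True swaps the module's parameters for the tensors in the state dict instead of
# copying values into the existing dense parameters, so the GGMLTensor wrapper is preserved.
model.load_state_dict(sd, assign=True)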
@@ -30,6 +30,8 @@ from invokeai.backend.model_manager.config import (
     SchedulerPredictionType,
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -187,6 +189,7 @@ class ModelProbe(object):
         if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
             ModelFormat.Checkpoint,
             ModelFormat.BnbQuantizednf4b,
+            ModelFormat.GGUFQuantized,
         ]:
             ckpt_config_path = cls._get_checkpoint_config_path(
                 model_path,
@@ -220,7 +223,7 @@ class ModelProbe(object):
     @classmethod
     def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
-        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
+        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
             raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")

         if model_path.name == "learned_embeds.bin":
@@ -408,6 +411,8 @@ class ModelProbe(object):
             model = torch.load(model_path, map_location="cpu")
             assert isinstance(model, dict)
             return model
+        elif model_path.suffix.endswith(".gguf"):
+            return gguf_sd_loader(model_path, compute_dtype=torch.float32)
         else:
             return safetensors.torch.load_file(model_path)
@@ -477,6 +482,8 @@ class CheckpointProbeBase(ProbeBase):
             or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
         ):
             return ModelFormat.BnbQuantizednf4b
+        elif any(isinstance(v, GGMLTensor) for v in state_dict.values()):
+            return ModelFormat.GGUFQuantized
         return ModelFormat("checkpoint")

     def get_variant_type(self) -> ModelVariantType:
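A small, hedged sketch of the detection rule used by CheckpointProbeBase.get_format above (the file name is a placeholder): load a .gguf state dict with the same float32 compute dtype the probe uses, then classify it as GGUF-quantized if any value is a GGMLTensor.

from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# Hypothetical file name; any GGUF checkpoint on disk would do.
state_dict = gguf_sd_loader(Path("flux1-dev-Q4_K_S.gguf"), compute_dtype=torch.float32)

# Same rule as the probe: the presence of any GGMLTensor value marks the checkpoint
# as ModelFormat.GGUFQuantized.
is_gguf = any(isinstance(v, GGMLTensor) for v in state_dict.values())
print(is_gguf)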
@@ -130,7 +130,7 @@ class ModelSearch:
             return

         for n in file_names:
-            if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
+            if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt", ".gguf")):
                 try:
                     self.model_found(absolute_path / n)
                 except KeyboardInterrupt:
@@ -8,6 +8,8 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path

+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
+

 def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
     checkpoint = {}
@@ -54,7 +56,10 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
         scan_result = scan_file_path(path)
         if scan_result.infected_files != 0:
             raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
-    checkpoint = torch.load(path, map_location=torch.device("meta"))
+    if str(path).endswith(".gguf"):
+        checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
+    else:
+        checkpoint = torch.load(path, map_location=torch.device("meta"))
     return checkpoint
invokeai/backend/quantization/gguf/ggml_tensor.py (new file, +152 lines)
@@ -0,0 +1,152 @@
from typing import overload

import gguf
import torch

from invokeai.backend.quantization.gguf.utils import (
    DEQUANTIZE_FUNCTIONS,
    TORCH_COMPATIBLE_QTYPES,
    dequantize,
)


def dequantize_and_run(func, args, kwargs):
    """A helper function for running math ops on GGMLTensor inputs.

    Dequantizes the inputs, and runs the function.
    """
    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
    dequantized_kwargs = {
        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
    }
    return func(*dequantized_args, **dequantized_kwargs)


def apply_to_quantized_tensor(func, args, kwargs):
    """A helper function to apply a function to a quantized GGML tensor, and re-wrap the result in a GGMLTensor.

    Assumes that the first argument is a GGMLTensor.
    """
    # We expect the first argument to be a GGMLTensor, and all other arguments to be non-GGMLTensors.
    ggml_tensor = args[0]
    assert isinstance(ggml_tensor, GGMLTensor)
    assert all(not isinstance(a, GGMLTensor) for a in args[1:])
    assert all(not isinstance(v, GGMLTensor) for v in kwargs.values())

    new_data = func(ggml_tensor.quantized_data, *args[1:], **kwargs)

    if new_data.dtype != ggml_tensor.quantized_data.dtype:
        # This is intended to catch calls such as `.to(dtype=torch.float32)`, which are not supported on GGMLTensors.
        raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")

    return GGMLTensor(
        new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape, ggml_tensor.compute_dtype
    )


GGML_TENSOR_OP_TABLE = {
    # Ops to run on the quantized tensor.
    torch.ops.aten.detach.default: apply_to_quantized_tensor,  # pyright: ignore
    torch.ops.aten._to_copy.default: apply_to_quantized_tensor,  # pyright: ignore
    # Ops to run on dequantized tensors.
    torch.ops.aten.t.default: dequantize_and_run,  # pyright: ignore
    torch.ops.aten.addmm.default: dequantize_and_run,  # pyright: ignore
    torch.ops.aten.mul.Tensor: dequantize_and_run,  # pyright: ignore
}


class GGMLTensor(torch.Tensor):
    """A torch.Tensor sub-class holding a quantized GGML tensor.

    The underlying tensor is quantized, but the GGMLTensor class provides a dequantized view of the tensor on-the-fly
    when it is used in operations.
    """

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        ggml_quantization_type: gguf.GGMLQuantizationType,
        tensor_shape: torch.Size,
        compute_dtype: torch.dtype,
    ):
        # Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
        return torch.Tensor._make_wrapper_subclass(  # pyright: ignore
            cls,
            data.shape,
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )

    def __init__(
        self,
        data: torch.Tensor,
        ggml_quantization_type: gguf.GGMLQuantizationType,
        tensor_shape: torch.Size,
        compute_dtype: torch.dtype,
    ):
        self.quantized_data = data
        self._ggml_quantization_type = ggml_quantization_type
        # The dequantized shape of the tensor.
        self._tensor_shape = tensor_shape
        self.compute_dtype = compute_dtype

    def __repr__(self, *, tensor_contents=None):
        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"

    @overload
    def size(self, dim: None = None) -> torch.Size: ...

    @overload
    def size(self, dim: int) -> int: ...

    def size(self, dim: int | None = None):
        """Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
        if dim is not None:
            return self._tensor_shape[dim]
        return self._tensor_shape

    @property
    def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
        """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
        return self.size()

    @property
    def quantized_shape(self) -> torch.Size:
        """The shape of the quantized tensor."""
        return self.quantized_data.shape

    def requires_grad_(self, mode: bool = True) -> torch.Tensor:
        """The GGMLTensor class is currently only designed for inference (not training). Setting requires_grad to True
        is not supported. This method is a no-op.
        """
        return self

    def get_dequantized_tensor(self):
        """Return the dequantized tensor.

        Args:
            dtype: The dtype of the dequantized tensor.
        """
        if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
            return self.quantized_data.to(self.compute_dtype)
        elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
            # TODO(ryand): Look into how the dtype param is intended to be used.
            return dequantize(
                data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
            ).to(self.compute_dtype)
        else:
            # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
            new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), self._ggml_quantization_type)
            return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # We will likely hit cases here in the future where a new op is encountered that is not yet supported.
        # The new op simply needs to be added to the GGML_TENSOR_OP_TABLE.
        if func in GGML_TENSOR_OP_TABLE:
            return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
        return NotImplemented
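A brief, hedged example of the wrapper's core contract (mirroring the test file later in this diff): quantize a random tensor to Q8_0 with the gguf reference implementation, wrap it in GGMLTensor, and confirm that the reported shape is the dequantized shape while get_dequantized_tensor() round-trips within quantization error.

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

x = torch.randn(32, 64)
q = torch.from_numpy(gguf.quantize(x.numpy(), gguf.GGMLQuantizationType.Q8_0))
gt = GGMLTensor(
    q,
    ggml_quantization_type=gguf.GGMLQuantizationType.Q8_0,
    tensor_shape=x.shape,
    compute_dtype=torch.float32,
)

assert gt.shape == x.shape             # the dequantized shape is what math ops see
assert gt.quantized_shape != x.shape   # the raw packed bytes have a different shape
assert torch.allclose(gt.get_dequantized_tensor(), x, atol=1e-1)  # within Q8_0 error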
invokeai/backend/quantization/gguf/loaders.py (new file, +22 lines)
@@ -0,0 +1,22 @@
from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
    reader = gguf.GGUFReader(path)

    sd: dict[str, GGMLTensor] = {}
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGMLTensor(
            torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
        )
    return sd
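A hedged usage sketch of gguf_sd_loader (the file path is a placeholder): open a GGUF checkpoint and report each tensor in the resulting state dict; GGMLTensor's __repr__ shows the GGML quantization type and the dequantized shape.

from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# "flux1-dev-Q8_0.gguf" is a placeholder path; substitute any GGUF checkpoint.
sd = gguf_sd_loader(Path("flux1-dev-Q8_0.gguf"), compute_dtype=torch.bfloat16)

for name, tensor in sd.items():
    print(f"{name}: {tensor!r}")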
invokeai/backend/quantization/gguf/utils.py (new file, +327 lines)
@@ -0,0 +1,327 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF

from typing import Callable, Optional, Union

import gguf
import torch

TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

# K Quants #
QK_K = 256
K_SCALE_SIZE = 12


def get_scale_min(scales: torch.Tensor):
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))


# Legacy Quants #
def dequantize_blocks_Q8_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x


def dequantize_blocks_Q5_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = ql | (qh << 4)
    return (d * qs) + m


def dequantize_blocks_Q5_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs


def dequantize_blocks_Q4_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return d * qs


def dequantize_blocks_BF16(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


def dequantize_blocks_Q6_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    (
        ql,
        qh,
        scales,
        d,
    ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 2, 1)
    )
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
        [0, 2, 4, 6], device=d.device, dtype=torch.uint8
    ).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = scales.to(torch.int8) - 32

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = ql.to(torch.int8) - (qh << 2).to(torch.int8)

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


DEQUANTIZE_FUNCTIONS: dict[
    gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


def is_torch_compatible(tensor: Optional[torch.Tensor]):
    return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor: torch.Tensor):
    return not is_torch_compatible(tensor)


def dequantize_tensor(
    tensor: torch.Tensor, dtype: torch.dtype, dequant_dtype: Union[torch.dtype, str, None] = None
) -> torch.Tensor:
    qtype: Optional[gguf.GGMLQuantizationType] = getattr(tensor, "tensor_type", None)
    oshape: torch.Size = getattr(tensor, "tensor_shape", tensor.shape)
    if qtype is None:
        raise ValueError("This is not a valid quantized tensor")
    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in DEQUANTIZE_FUNCTIONS:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        if not (dequant_dtype is None or isinstance(dequant_dtype, torch.dtype)):
            raise ValueError("dequant_dtype must be a torch.dtype")
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(
    data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]

    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def to_uint32(x: torch.Tensor) -> torch.Tensor:
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks: torch.Tensor, *args):
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
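A hedged sanity check of the dequantize() helper above, using only calls that appear in this diff (gguf.quantize, gguf.quants.dequantize, gguf.GGML_QUANT_SIZES): quantize a random matrix to Q8_0, dequantize it with the torch kernels, and compare against gguf's own numpy dequantizer. For Q8_0, GGML_QUANT_SIZES reports a block of 32 weights packed into 34 bytes (a 2-byte fp16 scale plus 32 int8 values).

import gguf
import numpy as np
import torch

from invokeai.backend.quantization.gguf.utils import dequantize

qtype = gguf.GGMLQuantizationType.Q8_0
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
assert (block_size, type_size) == (32, 34)

x = np.random.randn(64, 128).astype(np.float32)
q = gguf.quantize(x, qtype)  # packed uint8 blocks, shape (64, 128 // 32 * 34)

ours = dequantize(torch.from_numpy(q), qtype, torch.Size((64, 128)), dtype=torch.float32)
reference = torch.from_numpy(gguf.quants.dequantize(q, qtype))

assert torch.allclose(ours, reference, atol=1e-4)   # torch kernels match the numpy reference
assert np.allclose(ours.numpy(), x, atol=1e-1)      # remaining difference is quantization error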
@@ -39,6 +39,7 @@ dependencies = [
     "compel==2.0.2",
     "controlnet-aux==0.0.7",
     "diffusers[torch]==0.27.2",
+    "gguf==0.10.0",
     "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
     "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
     "numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
@@ -51,10 +52,10 @@ dependencies = [
     "sentencepiece==0.2.0",
     "spandrel==0.3.4",
     "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-    "torch==2.2.2",
+    "torch==2.4.1",
     "torchmetrics==0.11.4",
     "torchsde==0.2.6",
-    "torchvision==0.17.2",
+    "torchvision==0.19.1",
     "transformers==4.41.1",

     # Core application dependencies, pinned for reproducible builds.
tests/backend/quantization/gguf/test_ggml_tensor.py (new file, +115 lines)
@@ -0,0 +1,115 @@
import gguf
import pytest
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


def quantize_tensor(data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType) -> GGMLTensor:
    """Quantize a torch.Tensor to a GGMLTensor.

    Uses the gguf library's numpy implementation to quantize the tensor.
    """
    data_np = data.detach().cpu().numpy()
    quantized_np = gguf.quantize(data_np, ggml_quantization_type)
    return GGMLTensor(
        data=torch.from_numpy(quantized_np),
        ggml_quantization_type=ggml_quantization_type,
        tensor_shape=data.shape,
        compute_dtype=data.dtype,
    ).to(device=data.device)  # type: ignore


@pytest.mark.parametrize(
    ["device", "x1_quant_type", "x2_quant_type"],
    [
        # Test with no quantization.
        ("cpu", None, None),
        # Test with Q8_0 quantization.
        ("cpu", gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q8_0),
        ("cpu", None, gguf.GGMLQuantizationType.Q8_0),
        ("cpu", gguf.GGMLQuantizationType.Q8_0, None),
        # Test with F16 quantization (i.e. torch-compatible quantization).
        ("cpu", gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.F16),
        ("cpu", None, gguf.GGMLQuantizationType.F16),
        ("cpu", gguf.GGMLQuantizationType.F16, None),
        # Test all of the above cases on CUDA.
        ("cuda", None, None),
        # Test with Q8_0 quantization.
        ("cuda", gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q8_0),
        ("cuda", None, gguf.GGMLQuantizationType.Q8_0),
        ("cuda", gguf.GGMLQuantizationType.Q8_0, None),
        # Test with F16 quantization (i.e. torch-compatible quantization).
        ("cuda", gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.F16),
        ("cuda", None, gguf.GGMLQuantizationType.F16),
        ("cuda", gguf.GGMLQuantizationType.F16, None),
    ],
)
def test_ggml_tensor_multiply(
    device: str, x1_quant_type: gguf.GGMLQuantizationType | None, x2_quant_type: gguf.GGMLQuantizationType | None
):
    # Skip test if CUDA is not available.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")

    generator = torch.Generator().manual_seed(123)

    x1 = torch.randn(32, 64, generator=generator).to(device=device)
    x2 = torch.randn(32, 64, generator=generator).to(device=device)

    # Quantize the tensors.
    x1_quantized = quantize_tensor(x1, x1_quant_type) if x1_quant_type is not None else x1
    x2_quantized = quantize_tensor(x2, x2_quant_type) if x2_quant_type is not None else x2

    # Check devices.
    for x in [x1, x2, x1_quantized, x2_quantized]:
        assert x.device.type == device

    # Perform the multiplication.
    result = x1 * x2
    result_quantized = x1_quantized * x2_quantized

    assert result.shape == result_quantized.shape
    assert result.dtype == result_quantized.dtype
    assert torch.allclose(result, result_quantized, atol=1e-1)


def test_ggml_tensor_to_dtype_raises_error():
    x = torch.randn(32, 64)
    x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)

    with pytest.raises(ValueError):
        x_quantized.to(dtype=torch.float32)

    with pytest.raises(ValueError):
        x_quantized.float()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
def test_ggml_tensor_to_device():
    x = torch.randn(32, 64)
    x_cpu = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)

    x_gpu = x_cpu.to(device=torch.device("cuda"))

    assert x_cpu.device.type == "cpu"
    assert x_gpu.device.type == "cuda"

    assert torch.allclose(x_cpu.quantized_data, x_gpu.quantized_data.cpu(), atol=1e-5)


def test_ggml_tensor_shape():
    x = torch.randn(32, 64)
    x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)

    assert x_quantized.shape == x.shape
    assert x_quantized.size() == x.size()


def test_ggml_tensor_quantized_shape():
    x = torch.randn(32, 64)
    x_quantized = quantize_tensor(x, gguf.GGMLQuantizationType.Q8_0)

    # This is mainly just a smoke test to confirm that .quantized_shape can be accessed and doesn't hit any weird
    # dispatch errors.
    assert x_quantized.quantized_shape != x.shape