Initial GGUF support for flux models

Authored by Brandon Rising on 2024-09-19 15:00:49 -04:00
Committed by Kent Keirsey
parent 950c9f5d0c
commit 2bfb0ddff5
8 changed files with 749 additions and 3 deletions

View File

@@ -114,6 +114,7 @@ class ModelFormat(str, Enum):
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
class SchedulerPredictionType(str, Enum):
@@ -197,7 +198,7 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
-    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
+    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
@@ -363,6 +364,21 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.GGUFQuantized
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
@@ -466,6 +482,7 @@ AnyModelConfig = Annotated[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
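
A quick illustration (not part of the diff) of what the new union entry buys: stored model records are routed to MainGGUFCheckpointConfig via the discriminator tag returned by get_tag(), which is simply the type and format enum values joined with a dot.

from invokeai.backend.model_manager.config import ModelFormat, ModelType

# The tag MainGGUFCheckpointConfig.get_tag() produces; records whose type/format
# fields match it are validated as MainGGUFCheckpointConfig.
tag = f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}"
assert tag == "main.gguf_quantized"
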

View File

@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
@@ -35,6 +36,8 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
from invokeai.backend.model_manager.util.model_util import (
convert_bundle_to_flux_transformer_checkpoint,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
@@ -204,6 +207,50 @@ class FluxCheckpointModel(ModelLoader):
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class FluxGGUFCheckpointModel(ModelLoader):
"""Class to load GGUF main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainGGUFCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings(), GGUFPatcher().wrap():
            # Load the quantized state dict from the GGUF file
sd = gguf_sd_loader(model_path)
# Initialize the model
model = Flux(params[config.config_path])
# Calculate new state dictionary size and make room in the cache
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
self._ram_cache.make_room(new_sd_size)
# Load the state dict into the model
model.load_state_dict(sd, assign=True)
# Return the model after patching
return model
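
A small sketch (not part of the commit) of the make_room() estimate used in _load_from_singlefile above: it sums nelement() over the loaded state dict and multiplies by bfloat16's 2-byte itemsize before asking the RAM cache for space.

import torch

def estimated_size_bytes(sd: dict[str, torch.Tensor]) -> int:
    # Mirrors the expression passed to self._ram_cache.make_room() above.
    return sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())

demo = {"weight": torch.zeros(4, 8), "bias": torch.zeros(8)}
assert estimated_size_bytes(demo) == (4 * 8 + 8) * 2  # 40 elements * 2 bytes
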
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""

View File

@@ -30,6 +30,8 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -187,6 +189,7 @@ class ModelProbe(object):
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
@@ -220,7 +223,7 @@ class ModelProbe(object):
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
-        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
+        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"):
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
if model_path.name == "learned_embeds.bin":
@@ -408,6 +411,8 @@ class ModelProbe(object):
model = torch.load(model_path, map_location="cpu")
assert isinstance(model, dict)
return model
elif model_path.suffix.endswith(".gguf"):
return gguf_sd_loader(model_path)
else:
return safetensors.torch.load_file(model_path)
@@ -477,6 +482,8 @@ class CheckpointProbeBase(ProbeBase):
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
):
return ModelFormat.BnbQuantizednf4b
elif any(isinstance(v, GGUFTensor) for v in state_dict.values()):
return ModelFormat.GGUFQuantized
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
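
A usage sketch (not part of the commit; the file path is hypothetical) of the probe's new GGUF branch in isolation: the checkpoint is read with gguf_sd_loader and classified as gguf_quantized because every value it returns is a GGUFTensor.

from pathlib import Path

from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

ckpt = gguf_sd_loader(Path("/models/flux1-dev-Q8_0.gguf"))  # hypothetical path
assert any(isinstance(v, GGUFTensor) for v in ckpt.values())  # -> ModelFormat.GGUFQuantized
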

View File

@@ -8,6 +8,8 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
checkpoint = {}
@@ -54,7 +56,10 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
-        checkpoint = torch.load(path, map_location=torch.device("meta"))
+        if str(path).endswith(".gguf"):
+            checkpoint = gguf_sd_loader(Path(path))
+        else:
+            checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
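
Usage sketch (not part of the commit; file names are hypothetical): .gguf paths are now routed through gguf_sd_loader, while every other suffix keeps the existing meta-device torch.load path.

from invokeai.backend.model_manager.util.model_util import read_checkpoint_meta

gguf_meta = read_checkpoint_meta("/models/flux1-dev-Q4_K_S.gguf")  # via gguf_sd_loader
ckpt_meta = read_checkpoint_meta("/models/v1-5-pruned.ckpt")       # unchanged torch.load path
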

View File

@@ -0,0 +1,151 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from typing import Optional, Union
import gguf
from torch import Tensor, device, dtype, float32, nn, zeros_like
from invokeai.backend.quantization.gguf.utils import dequantize_tensor, is_quantized
PATCH_TYPES = Union[list[Tensor], tuple[Tensor]]
class GGUFTensor(Tensor):
"""
Main tensor-like class for storing quantized weights
"""
def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs):
super().__init__()
self.tensor_type = tensor_type
self.tensor_shape = tensor_shape
self.patches = patches or []
def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
return super().__new__(cls, *args, **kwargs)
def to(self, *args, **kwargs):
new = super().to(*args, **kwargs)
new.tensor_type = getattr(self, "tensor_type", None)
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
new.patches = getattr(self, "patches", []).copy()
return new
def clone(self, *args, **kwargs):
return self
def detach(self, *args, **kwargs):
return self
def copy_(self, *args, **kwargs):
# fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
try:
return super().copy_(*args, **kwargs)
except Exception as e:
print(f"ignoring 'copy_' on tensor: {e}")
def __deepcopy__(self, *args, **kwargs):
# Intel Arc fix, ref#50
new = super().__deepcopy__(*args, **kwargs)
new.tensor_type = getattr(self, "tensor_type", None)
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
new.patches = getattr(self, "patches", []).copy()
return new
@property
def shape(self):
if not hasattr(self, "tensor_shape"):
self.tensor_shape = self.size()
return self.tensor_shape
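
A sketch (not part of the commit) of what the shape override provides: the raw quantized payload keeps its byte layout while .shape reports the logical, dequantized shape. The sizes below assume a hypothetical Q8_0 weight of logical shape (3072, 3072), which packs 32 elements into 34-byte blocks.

import gguf
import torch

from invokeai.backend.quantization.gguf.layers import GGUFTensor

raw = torch.zeros(3072, 3264, dtype=torch.uint8)  # 3072/32 blocks per row * 34 bytes = 3264 bytes
t = GGUFTensor(raw, tensor_type=gguf.GGMLQuantizationType.Q8_0, tensor_shape=torch.Size((3072, 3072)))
assert t.shape == torch.Size((3072, 3072))  # logical shape reported to callers
# t.size() still reflects the raw (3072, 3264) byte layout; dequantize_tensor() uses
# tensor_type/tensor_shape to recover the logical tensor only when it is needed.
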
class GGUFLayer(nn.Module):
"""
    This mixin is responsible for dequantizing GGUF-quantized weights on the fly.
"""
dequant_dtype = None
patch_dtype = None
torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None):
if weight is None or bias is None:
return False
return is_quantized(weight) or is_quantized(bias)
def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, **kwargs):
weight, bias = state_dict.get(f"{prefix}weight", None), state_dict.get(f"{prefix}bias", None)
if self.is_ggml_quantized(weight=weight, bias=bias):
return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def ggml_load_from_state_dict(
self,
state_dict: dict[str, Tensor],
prefix: str,
local_metadata,
strict,
missing_keys: list[str],
unexpected_keys,
error_msgs,
):
for k, v in state_dict.items():
if k.endswith("weight"):
self.weight = nn.Parameter(v, requires_grad=False)
elif k.endswith("bias") and v is not None:
self.bias = nn.Parameter(v, requires_grad=False)
else:
missing_keys.append(k)
def _save_to_state_dict(self, *args, **kwargs):
if self.is_ggml_quantized():
return self.ggml_save_to_state_dict(*args, **kwargs)
return super()._save_to_state_dict(*args, **kwargs)
def ggml_save_to_state_dict(self, destination: dict[str, Tensor], prefix: str):
# This is a fake state dict for vram estimation
weight = zeros_like(self.weight, device=device("meta"))
destination[prefix + "weight"] = weight
if self.bias is not None:
bias = zeros_like(self.bias, device=device("meta"))
destination[prefix + "bias"] = bias
return
def get_weight(self, tensor: Optional[Tensor], dtype: dtype):
if tensor is None:
return
        # Dequantize the tensor to the requested dtype (LoRA-style patches are not applied here)
weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
return weight
def calc_size(self) -> int:
"""Get the size of this model in bytes."""
return self.bias.nelement() * self.bias.element_size()
def cast_bias_weight(
self,
input: Tensor,
dtype: Optional[dtype] = None,
device: Optional[device] = None,
bias_dtype: Optional[dtype] = None,
) -> tuple[Tensor, Tensor]:
if dtype is None:
dtype = getattr(input, "dtype", float32)
if dtype is None:
raise ValueError("dtype is required")
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = self.get_weight(self.bias.to(device), dtype)
if bias is not None:
bias = bias.to(dtype=bias_dtype, device=device, copy=False)
weight = self.get_weight(self.weight.to(device), dtype)
if weight is not None:
weight = weight.to(dtype=dtype, device=device)
if weight is None or bias is None:
raise ValueError("Weight or bias is None")
return weight, bias
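
A sketch (not part of the commit) of the pattern cast_bias_weight() enables: quantized parameters stay in their compact GGUF form until a forward pass needs them, then they are dequantized to the input's dtype and device. dequantize_tensor comes from the utils module added later in this commit; jit_linear is a hypothetical stand-in for the patched layers in torch_patcher.py.

import torch

from invokeai.backend.quantization.gguf.utils import dequantize_tensor

def jit_linear(x: torch.Tensor, q_weight: torch.Tensor, q_bias: torch.Tensor) -> torch.Tensor:
    # q_weight / q_bias are GGUFTensor parameters produced by gguf_sd_loader.
    weight = dequantize_tensor(q_weight.to(x.device), x.dtype)
    bias = dequantize_tensor(q_bias.to(x.device), x.dtype)
    return torch.nn.functional.linear(x, weight, bias)
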

View File

@@ -0,0 +1,68 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from pathlib import Path
import gguf
import torch
from invokeai.backend.quantization.gguf.layers import GGUFTensor
from invokeai.backend.quantization.gguf.utils import detect_arch
def gguf_sd_loader(
path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
) -> dict[str, GGUFTensor]:
"""
    Read a GGUF file into a state dict of GGUFTensor objects, preserving the quantized data.
"""
reader = gguf.GGUFReader(path)
prefix_len = len(handle_prefix)
tensor_names = {tensor.name for tensor in reader.tensors}
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
tensors: list[tuple[str, gguf.ReaderTensor]] = []
for tensor in reader.tensors:
sd_key = tensor_name = tensor.name
if has_prefix:
if not tensor_name.startswith(handle_prefix):
continue
sd_key = tensor_name[prefix_len:]
tensors.append((sd_key, tensor))
# detect and verify architecture
compat = None
arch_str = None
arch_field = reader.get_field("general.architecture")
if arch_field is not None:
if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
if arch_str not in {"flux"}:
raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
else:
arch_str = detect_arch({val[0] for val in tensors})
compat = "sd.cpp"
# main loading loop
state_dict: dict[str, GGUFTensor] = {}
qtype_dict: dict[str, int] = {}
for sd_key, tensor in tensors:
tensor_name = tensor.name
tensor_type_str = str(tensor.tensor_type)
torch_tensor = torch.from_numpy(tensor.data) # mmap
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
# Workaround for stable-diffusion.cpp SDXL detection.
if compat == "sd.cpp" and arch_str == "sdxl":
if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
while len(shape) > 2 and shape[-1] == 1:
shape = shape[:-1]
# add to state dict
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
torch_tensor = torch_tensor.view(*shape)
state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
return state_dict
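
Usage sketch (not part of the commit; the path is hypothetical): tensor data is memory-mapped by GGUFReader, so this is cheap; quantized tensors come back as GGUFTensor objects that remember their quantization type and logical shape.

from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("/models/flux1-dev-Q4_K_S.gguf"))
for name, tensor in list(sd.items())[:3]:
    print(name, tensor.tensor_type, tuple(tensor.shape))
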

View File

@@ -0,0 +1,90 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from contextlib import contextmanager
from typing import Generator, Optional
import wrapt
from torch import Tensor, bfloat16, dtype, float16, nn
from invokeai.backend.quantization.gguf.layers import GGUFLayer
class TorchPatcher:
@classmethod
@contextmanager
def wrap(cls) -> Generator[None, None, None]:
# Dictionary to store original torch.nn classes for later restoration
original_classes = {}
try:
# Iterate over _patcher's attributes and replace matching torch.nn classes
for attr_name in dir(cls):
if attr_name.startswith("__"):
continue
# Get the class from _patcher
patcher_class = getattr(cls, attr_name)
# Check if torch.nn has a class with the same name
if hasattr(nn, attr_name):
# Get the original torch.nn class
original_class = getattr(nn, attr_name)
# Define a helper function to bind the current patcher_attr for each iteration
def create_patch_function(patcher_attr):
# Return a new patch_class function specific to this patcher_attr
@wrapt.decorator
def patch_class(wrapped, instance, args, kwargs):
# Call the _patcher version of the class
return patcher_attr(*args, **kwargs)
return patch_class
# Save the original class for restoration later
original_classes[attr_name] = original_class
# Apply the patch
setattr(nn, attr_name, create_patch_function(patcher_class)(original_class))
yield
finally:
# Restore the original torch.nn classes
for attr_name, original_class in original_classes.items():
setattr(nn, attr_name, original_class)
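
A toy demonstration (not part of the commit) of what wrap() does, using a minimal patcher subclass: while the context is active, constructing nn.Linear yields the patcher's class; on exit the original torch.nn classes are restored.

from torch import nn

from invokeai.backend.quantization.gguf.torch_patcher import TorchPatcher

class _DemoPatcher(TorchPatcher):
    class Linear(nn.Linear):
        pass

with _DemoPatcher.wrap():
    layer = nn.Linear(4, 4)  # constructed through the patched class
    assert isinstance(layer, _DemoPatcher.Linear)

assert nn.Linear(4, 4).__class__ is nn.Linear  # torch.nn restored on exit
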
class GGUFPatcher(TorchPatcher):
"""
Dequantize weights on the fly before doing the compute
"""
class Linear(GGUFLayer, nn.Linear):
def forward(self, input: Tensor) -> Tensor:
weight, bias = self.cast_bias_weight(input)
return nn.functional.linear(input, weight, bias)
class Conv2d(GGUFLayer, nn.Conv2d):
def forward(self, input: Tensor) -> Tensor:
weight, bias = self.cast_bias_weight(input)
return self._conv_forward(input, weight, bias)
class Embedding(GGUFLayer, nn.Embedding):
def forward(self, input: Tensor, out_dtype: Optional[dtype] = None) -> Tensor:
output_dtype = out_dtype
            if self.weight is None:
raise ValueError("Embedding layer must have a weight")
if self.weight.dtype == float16 or self.weight.dtype == bfloat16:
out_dtype = None
weight, _ = self.cast_bias_weight(input, device=input.device, dtype=out_dtype)
return nn.functional.embedding(
input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
).to(dtype=output_dtype)
class LayerNorm(GGUFLayer, nn.LayerNorm):
def forward(self, input: Tensor) -> Tensor:
if self.weight is None:
return nn.functional.layer_norm(input, self.normalized_shape, None, None, self.eps)
weight, bias = self.cast_bias_weight(input)
return nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
class GroupNorm(GGUFLayer, nn.GroupNorm):
def forward(self, input: Tensor) -> Tensor:
weight, bias = self.cast_bias_weight(input)
return nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
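
End-to-end usage sketch (not part of the commit): this mirrors FluxGGUFCheckpointModel._load_from_singlefile above. The path and build_transformer() helper are hypothetical stand-ins for the loader's Flux(params[config.config_path]) call; the point is that modules constructed inside the patcher context use the dequantize-on-forward layers above, and assign=True keeps the GGUFTensor parameters as loaded.

from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher

with GGUFPatcher().wrap():
    sd = gguf_sd_loader(Path("/models/flux1-dev-Q4_K_S.gguf"))  # hypothetical path
    model = build_transformer()             # hypothetical stand-in for Flux(params[...])
    model.load_state_dict(sd, assign=True)  # assign=True keeps the GGUFTensor objects as-is
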

View File

@@ -0,0 +1,361 @@
# Largely based on https://github.com/city96/ComfyUI-GGUF
from typing import Callable, Optional, Union
import gguf
import torch
TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
# K Quants #
QK_K = 256
K_SCALE_SIZE = 12
MODEL_DETECTION = (
(
"flux",
(
("transformer_blocks.0.attn.norm_added_k.weight",),
("double_blocks.0.img_attn.proj.weight",),
),
),
)
def get_scale_min(scales: torch.Tensor):
n_blocks = scales.shape[0]
scales = scales.view(torch.uint8)
scales = scales.reshape((n_blocks, 3, 4))
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
# Legacy Quants #
def dequantize_blocks_Q8_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
x = x.view(torch.int8)
return d * x
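
A worked example (not part of the commit) for the simplest case: a Q8_0 block is 34 bytes, a float16 scale d followed by 32 int8 quants, and each output element is just d * q. The import assumes the module path used for these helpers elsewhere in the commit.

import torch

from invokeai.backend.quantization.gguf.utils import dequantize_blocks_Q8_0

scale = torch.tensor([0.5], dtype=torch.float16).view(torch.uint8)  # 2 bytes: the block scale d
quants = torch.arange(-16, 16, dtype=torch.int8).view(torch.uint8)  # 32 bytes: int8 quants
block = torch.cat([scale, quants]).unsqueeze(0)                     # one (1, 34) Q8_0 block

out = dequantize_blocks_Q8_0(block, block_size=32, type_size=34, dtype=torch.float32)
assert torch.equal(out[0], 0.5 * torch.arange(-16, 16, dtype=torch.float32))
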
def dequantize_blocks_Q5_1(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape((n_blocks, -1))
qs = ql | (qh << 4)
return (d * qs) + m
def dequantize_blocks_Q5_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, qh, qs = split_block_dims(blocks, 2, 4)
d = d.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape(n_blocks, -1)
qs = (ql | (qh << 4)).to(torch.int8) - 16
return d * qs
def dequantize_blocks_Q4_1(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, m, qs = split_block_dims(blocks, 2, 2)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qs = (qs & 0x0F).reshape(n_blocks, -1)
return (d * qs) + m
def dequantize_blocks_Q4_0(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, qs = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
return d * qs
def dequantize_blocks_BF16(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
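
A worked example (not part of the commit) of the BF16 path: each 16-bit pattern is widened into the high half of an int32 and reinterpreted as float32, which reproduces the bfloat16 value exactly (bfloat16 is float32 with the low 16 mantissa bits dropped).

import torch

from invokeai.backend.quantization.gguf.utils import dequantize_blocks_BF16

vals = torch.tensor([1.0, -2.5, 3.140625], dtype=torch.bfloat16)  # all exactly representable
raw = vals.view(torch.int16).unsqueeze(0)                         # the stored bit patterns
out = dequantize_blocks_BF16(raw, block_size=1, type_size=2)
assert torch.equal(out, vals.to(torch.float32).unsqueeze(0))
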
def dequantize_blocks_Q6_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
(
ql,
qh,
scales,
d,
) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
scales = scales.view(torch.int8).to(dtype)
d = d.view(torch.float16).to(dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = (qh & 0x03).reshape((n_blocks, -1, 32))
q = (ql | (qh << 4)).to(torch.int8) - 32
q = q.reshape((n_blocks, QK_K // 16, -1))
return (d * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q5_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = (qh & 0x01).reshape((n_blocks, -1, 32))
q = ql | (qh << 4)
return (d * q - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q4_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q3_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
d = d.view(torch.float16).to(dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 2, 1)
)
lscales = lscales.reshape((n_blocks, 16))
hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
[0, 2, 4, 6], device=d.device, dtype=torch.uint8
).reshape((1, 4, 1))
hscales = hscales.reshape((n_blocks, 16))
scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
scales = scales.to(torch.int8) - 32
dl = (d * scales).reshape((n_blocks, 16, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
return (dl * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q2_K(
blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
n_blocks = blocks.shape[0]
scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
qs = qs.reshape((n_blocks, QK_K // 16, 16))
qs = dl * qs - ml
return qs.reshape((n_blocks, -1))
DEQUANTIZE_FUNCTIONS: dict[
gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
def is_torch_compatible(tensor: Optional[torch.Tensor]):
return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
def is_quantized(tensor: torch.Tensor):
return not is_torch_compatible(tensor)
def dequantize_tensor(
tensor: torch.Tensor, dtype: torch.dtype, dequant_dtype: Union[torch.dtype, str, None] = None
) -> torch.Tensor:
qtype: Optional[gguf.GGMLQuantizationType] = getattr(tensor, "tensor_type", None)
oshape: torch.Size = getattr(tensor, "tensor_shape", tensor.shape)
if qtype is None:
raise ValueError("This is not a valid quantized tensor")
if qtype in TORCH_COMPATIBLE_QTYPES:
return tensor.to(dtype)
elif qtype in DEQUANTIZE_FUNCTIONS:
dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
else:
new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
return torch.from_numpy(new).to(tensor.device, dtype=dtype)
def dequantize(
data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
"""
Dequantize tensor back to usable shape/dtype
"""
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // type_size
blocks = rows.reshape((n_blocks, type_size))
blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
return blocks.reshape(oshape)
def to_uint32(x: torch.Tensor) -> torch.Tensor:
x = x.view(torch.uint8).to(torch.int32)
return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
def split_block_dims(blocks: torch.Tensor, *args):
n_max = blocks.shape[1]
dims = list(args) + [n_max - sum(args)]
return torch.split(blocks, dims, dim=1)
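
Worked examples (not part of the commit) for the two low-level helpers: split_block_dims carves a (n_blocks, type_size) byte matrix into fixed-width fields with the remainder as the last field, and to_uint32 reassembles four little-endian bytes per row into a single 32-bit value (used for the packed high-bit words in Q5_0/Q5_1).

import torch

from invokeai.backend.quantization.gguf.utils import split_block_dims, to_uint32

# A Q4_0 block is 18 bytes: a 2-byte scale, then 16 bytes of packed 4-bit quants.
blocks = torch.zeros(4, 18, dtype=torch.uint8)
d, qs = split_block_dims(blocks, 2)
assert d.shape == (4, 2) and qs.shape == (4, 16)

# Little-endian byte order: 0x78 0x56 0x34 0x12 -> 0x12345678.
assert to_uint32(torch.tensor([[0x78, 0x56, 0x34, 0x12]], dtype=torch.uint8)).item() == 0x12345678
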
PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
def move_patch_to_device(item: PATCH_TYPES, device: torch.device) -> PATCH_TYPES:
if isinstance(item, torch.Tensor):
return item.to(device, non_blocking=True)
elif isinstance(item, tuple):
if len(item) == 0:
return item
if not isinstance(item[0], torch.Tensor):
raise ValueError("Invalid item")
return tuple(move_patch_to_device(x, device) for x in item)
elif isinstance(item, list):
if len(item) == 0:
return item
if not isinstance(item[0], torch.Tensor):
raise ValueError("Invalid item")
return [move_patch_to_device(x, device) for x in item]
def detect_arch(state_dict: dict[str, torch.Tensor]):
for arch, match_lists in MODEL_DETECTION:
for match_list in match_lists:
if all(key in state_dict for key in match_list):
return arch
raise ValueError("Unknown model architecture!")
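
Usage sketch (not part of the commit): when a GGUF file has no general.architecture field, gguf_sd_loader falls back to detect_arch, which keys off well-known tensor names; the set below mirrors the set of state-dict keys the loader passes in.

from invokeai.backend.quantization.gguf.utils import detect_arch

assert detect_arch({"double_blocks.0.img_attn.proj.weight"}) == "flux"
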