diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index be654de751..f1c262df99 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -114,6 +114,7 @@ class ModelFormat(str, Enum): T5Encoder = "t5_encoder" BnbQuantizedLlmInt8b = "bnb_quantized_int8b" BnbQuantizednf4b = "bnb_quantized_nf4b" + GGUFQuantized = "gguf_quantized" class SchedulerPredictionType(str, Enum): @@ -197,7 +198,7 @@ class ModelConfigBase(BaseModel): class CheckpointConfigBase(ModelConfigBase): """Model config for checkpoint-style models.""" - format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field( + format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field( description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint ) config_path: str = Field(description="path to the checkpoint model config file") @@ -363,6 +364,21 @@ class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase): return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}") +class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase): + """Model config for main checkpoint models.""" + + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.format = ModelFormat.GGUFQuantized + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}") + + class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase): """Model config for main diffusers models.""" @@ -466,6 +482,7 @@ AnyModelConfig = Annotated[ Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()], + Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()], Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py index c7563c2c20..0ec461d99e 100644 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import ( CLIPEmbedDiffusersConfig, MainBnbQuantized4bCheckpointConfig, MainCheckpointConfig, + MainGGUFCheckpointConfig, T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderConfig, VAECheckpointConfig, @@ -35,6 +36,8 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade from invokeai.backend.model_manager.util.model_util import ( convert_bundle_to_flux_transformer_checkpoint, ) +from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader +from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher from invokeai.backend.util.silence_warnings import SilenceWarnings try: @@ -204,6 +207,50 @@ class FluxCheckpointModel(ModelLoader): return model +@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized) +class FluxGGUFCheckpointModel(ModelLoader): + """Class to load GGUF main models.""" + + def _load_model( + self, + 
config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, CheckpointConfigBase): + raise ValueError("Only CheckpointConfigBase models are currently supported here.") + + match submodel_type: + case SubModelType.Transformer: + return self._load_from_singlefile(config) + + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + def _load_from_singlefile( + self, + config: AnyModelConfig, + ) -> AnyModel: + assert isinstance(config, MainGGUFCheckpointConfig) + model_path = Path(config.path) + + with SilenceWarnings(), GGUFPatcher().wrap(): + # Load the state dict and patcher + sd = gguf_sd_loader(model_path) + # Initialize the model + model = Flux(params[config.config_path]) + + # Calculate new state dictionary size and make room in the cache + new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()]) + self._ram_cache.make_room(new_sd_size) + + # Load the state dict into the model + model.load_state_dict(sd, assign=True) + + # Return the model after patching + return model + + @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): """Class to load main models.""" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 48db855943..80298f0da2 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -30,6 +30,8 @@ from invokeai.backend.model_manager.config import ( SchedulerPredictionType, ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta +from invokeai.backend.quantization.gguf.layers import GGUFTensor +from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.util.silence_warnings import SilenceWarnings @@ -187,6 +189,7 @@ class ModelProbe(object): if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [ ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, + ModelFormat.GGUFQuantized, ]: ckpt_config_path = cls._get_checkpoint_config_path( model_path, @@ -220,7 +223,7 @@ class ModelProbe(object): @classmethod def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): + if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"): raise InvalidModelConfigException(f"{model_path}: unrecognized suffix") if model_path.name == "learned_embeds.bin": @@ -408,6 +411,8 @@ class ModelProbe(object): model = torch.load(model_path, map_location="cpu") assert isinstance(model, dict) return model + elif model_path.suffix.endswith(".gguf"): + return gguf_sd_loader(model_path) else: return safetensors.torch.load_file(model_path) @@ -477,6 +482,8 @@ class CheckpointProbeBase(ProbeBase): or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict ): return ModelFormat.BnbQuantizednf4b + elif any(isinstance(v, GGUFTensor) for v in state_dict.values()): + return ModelFormat.GGUFQuantized return ModelFormat("checkpoint") def get_variant_type(self) -> ModelVariantType: diff --git 
a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py index fd904ac335..9e89572977 100644 --- a/invokeai/backend/model_manager/util/model_util.py +++ b/invokeai/backend/model_manager/util/model_util.py @@ -8,6 +8,8 @@ import safetensors import torch from picklescan.scanner import scan_file_path +from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader + def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]: checkpoint = {} @@ -54,7 +56,10 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str scan_result = scan_file_path(path) if scan_result.infected_files != 0: raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') - checkpoint = torch.load(path, map_location=torch.device("meta")) + if str(path).endswith(".gguf"): + checkpoint = gguf_sd_loader(Path(path)) + else: + checkpoint = torch.load(path, map_location=torch.device("meta")) return checkpoint diff --git a/invokeai/backend/quantization/gguf/layers.py b/invokeai/backend/quantization/gguf/layers.py new file mode 100644 index 0000000000..c7a8b84d2a --- /dev/null +++ b/invokeai/backend/quantization/gguf/layers.py @@ -0,0 +1,151 @@ +# Largely based on https://github.com/city96/ComfyUI-GGUF + +from typing import Optional, Union + +import gguf +from torch import Tensor, device, dtype, float32, nn, zeros_like + +from invokeai.backend.quantization.gguf.utils import dequantize_tensor, is_quantized + +PATCH_TYPES = Union[list[Tensor], tuple[Tensor]] + + +class GGUFTensor(Tensor): + """ + Main tensor-like class for storing quantized weights + """ + + def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs): + super().__init__() + self.tensor_type = tensor_type + self.tensor_shape = tensor_shape + self.patches = patches or [] + + def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs): + return super().__new__(cls, *args, **kwargs) + + def to(self, *args, **kwargs): + new = super().to(*args, **kwargs) + new.tensor_type = getattr(self, "tensor_type", None) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + new.patches = getattr(self, "patches", []).copy() + return new + + def clone(self, *args, **kwargs): + return self + + def detach(self, *args, **kwargs): + return self + + def copy_(self, *args, **kwargs): + # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel + try: + return super().copy_(*args, **kwargs) + except Exception as e: + print(f"ignoring 'copy_' on tensor: {e}") + + def __deepcopy__(self, *args, **kwargs): + # Intel Arc fix, ref#50 + new = super().__deepcopy__(*args, **kwargs) + new.tensor_type = getattr(self, "tensor_type", None) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + new.patches = getattr(self, "patches", []).copy() + return new + + @property + def shape(self): + if not hasattr(self, "tensor_shape"): + self.tensor_shape = self.size() + return self.tensor_shape + + +class GGUFLayer(nn.Module): + """ + This (should) be responsible for de-quantizing on the fly + """ + + dequant_dtype = None + patch_dtype = None + torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16} + + def is_ggml_quantized(self, *, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None): + if weight is None or bias is None: + return False + return is_quantized(weight) or is_quantized(bias) + + def _load_from_state_dict(self, state_dict: dict[str, Tensor], prefix: str, *args, 
**kwargs): + weight, bias = state_dict.get(f"{prefix}weight", None), state_dict.get(f"{prefix}bias", None) + if self.is_ggml_quantized(weight=weight, bias=bias): + return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def ggml_load_from_state_dict( + self, + state_dict: dict[str, Tensor], + prefix: str, + local_metadata, + strict, + missing_keys: list[str], + unexpected_keys, + error_msgs, + ): + for k, v in state_dict.items(): + if k.endswith("weight"): + self.weight = nn.Parameter(v, requires_grad=False) + elif k.endswith("bias") and v is not None: + self.bias = nn.Parameter(v, requires_grad=False) + else: + missing_keys.append(k) + + def _save_to_state_dict(self, *args, **kwargs): + if self.is_ggml_quantized(): + return self.ggml_save_to_state_dict(*args, **kwargs) + return super()._save_to_state_dict(*args, **kwargs) + + def ggml_save_to_state_dict(self, destination: dict[str, Tensor], prefix: str): + # This is a fake state dict for vram estimation + weight = zeros_like(self.weight, device=device("meta")) + destination[prefix + "weight"] = weight + if self.bias is not None: + bias = zeros_like(self.bias, device=device("meta")) + destination[prefix + "bias"] = bias + return + + def get_weight(self, tensor: Optional[Tensor], dtype: dtype): + if tensor is None: + return + + # dequantize tensor while patches load + weight = dequantize_tensor(tensor, dtype, self.dequant_dtype) + return weight + + def calc_size(self) -> int: + """Get the size of this model in bytes.""" + return self.bias.nelement() * self.bias.element_size() + + def cast_bias_weight( + self, + input: Tensor, + dtype: Optional[dtype] = None, + device: Optional[device] = None, + bias_dtype: Optional[dtype] = None, + ) -> tuple[Tensor, Tensor]: + if dtype is None: + dtype = getattr(input, "dtype", float32) + if dtype is None: + raise ValueError("dtype is required") + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + + bias = self.get_weight(self.bias.to(device), dtype) + if bias is not None: + bias = bias.to(dtype=bias_dtype, device=device, copy=False) + + weight = self.get_weight(self.weight.to(device), dtype) + if weight is not None: + weight = weight.to(dtype=dtype, device=device) + if weight is None or bias is None: + raise ValueError("Weight or bias is None") + return weight, bias diff --git a/invokeai/backend/quantization/gguf/loaders.py b/invokeai/backend/quantization/gguf/loaders.py new file mode 100644 index 0000000000..977d2a1379 --- /dev/null +++ b/invokeai/backend/quantization/gguf/loaders.py @@ -0,0 +1,68 @@ +# Largely based on https://github.com/city96/ComfyUI-GGUF + +from pathlib import Path + +import gguf +import torch + +from invokeai.backend.quantization.gguf.layers import GGUFTensor +from invokeai.backend.quantization.gguf.utils import detect_arch + + +def gguf_sd_loader( + path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16 +) -> dict[str, GGUFTensor]: + """ + Read state dict as fake tensors + """ + reader = gguf.GGUFReader(path) + + prefix_len = len(handle_prefix) + tensor_names = {tensor.name for tensor in reader.tensors} + has_prefix = any(s.startswith(handle_prefix) for s in tensor_names) + + tensors: list[tuple[str, gguf.ReaderTensor]] = [] + for tensor in reader.tensors: + sd_key = tensor_name = tensor.name + if has_prefix: + if not tensor_name.startswith(handle_prefix): + continue + sd_key = tensor_name[prefix_len:] 
+ tensors.append((sd_key, tensor)) + + # detect and verify architecture + compat = None + arch_str = None + arch_field = reader.get_field("general.architecture") + if arch_field is not None: + if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING: + raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}") + arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8") + if arch_str not in {"flux"}: + raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}") + else: + arch_str = detect_arch({val[0] for val in tensors}) + compat = "sd.cpp" + + # main loading loop + state_dict: dict[str, GGUFTensor] = {} + qtype_dict: dict[str, int] = {} + for sd_key, tensor in tensors: + tensor_name = tensor.name + tensor_type_str = str(tensor.tensor_type) + torch_tensor = torch.from_numpy(tensor.data) # mmap + + shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape))) + # Workaround for stable-diffusion.cpp SDXL detection. + if compat == "sd.cpp" and arch_str == "sdxl": + if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")): + while len(shape) > 2 and shape[-1] == 1: + shape = shape[:-1] + + # add to state dict + if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}: + torch_tensor = torch_tensor.view(*shape) + state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape) + qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1 + + return state_dict diff --git a/invokeai/backend/quantization/gguf/torch_patcher.py b/invokeai/backend/quantization/gguf/torch_patcher.py new file mode 100644 index 0000000000..45771f1e63 --- /dev/null +++ b/invokeai/backend/quantization/gguf/torch_patcher.py @@ -0,0 +1,90 @@ +# Largely based on https://github.com/city96/ComfyUI-GGUF + +from contextlib import contextmanager +from typing import Generator, Optional + +import wrapt +from torch import Tensor, bfloat16, dtype, float16, nn + +from invokeai.backend.quantization.gguf.layers import GGUFLayer + + +class TorchPatcher: + @classmethod + @contextmanager + def wrap(cls) -> Generator[None, None, None]: + # Dictionary to store original torch.nn classes for later restoration + original_classes = {} + try: + # Iterate over _patcher's attributes and replace matching torch.nn classes + for attr_name in dir(cls): + if attr_name.startswith("__"): + continue + # Get the class from _patcher + patcher_class = getattr(cls, attr_name) + + # Check if torch.nn has a class with the same name + if hasattr(nn, attr_name): + # Get the original torch.nn class + original_class = getattr(nn, attr_name) + + # Define a helper function to bind the current patcher_attr for each iteration + def create_patch_function(patcher_attr): + # Return a new patch_class function specific to this patcher_attr + @wrapt.decorator + def patch_class(wrapped, instance, args, kwargs): + # Call the _patcher version of the class + return patcher_attr(*args, **kwargs) + + return patch_class + + # Save the original class for restoration later + original_classes[attr_name] = original_class + + # Apply the patch + setattr(nn, attr_name, create_patch_function(patcher_class)(original_class)) + yield + finally: + # Restore the original torch.nn classes + for attr_name, original_class in original_classes.items(): + setattr(nn, attr_name, original_class) + + +class GGUFPatcher(TorchPatcher): + """ + Dequantize weights on the fly before doing the compute 
+    """
+
+    class Linear(GGUFLayer, nn.Linear):
+        def forward(self, input: Tensor) -> Tensor:
+            weight, bias = self.cast_bias_weight(input)
+            return nn.functional.linear(input, weight, bias)
+
+    class Conv2d(GGUFLayer, nn.Conv2d):
+        def forward(self, input: Tensor) -> Tensor:
+            weight, bias = self.cast_bias_weight(input)
+            return self._conv_forward(input, weight, bias)
+
+    class Embedding(GGUFLayer, nn.Embedding):
+        def forward(self, input: Tensor, out_dtype: Optional[dtype] = None) -> Tensor:
+            output_dtype = out_dtype
+            if self.weight is None:
+                raise ValueError("Embedding layer must have a weight")
+            if self.weight.dtype == float16 or self.weight.dtype == bfloat16:
+                out_dtype = None
+            weight, _ = self.cast_bias_weight(input, device=input.device, dtype=out_dtype)
+            return nn.functional.embedding(
+                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
+            ).to(dtype=output_dtype)
+
+    class LayerNorm(GGUFLayer, nn.LayerNorm):
+        def forward(self, input: Tensor) -> Tensor:
+            if self.weight is None:
+                return nn.functional.layer_norm(input, self.normalized_shape, None, None, self.eps)
+            weight, bias = self.cast_bias_weight(input)
+            return nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+    class GroupNorm(GGUFLayer, nn.GroupNorm):
+        def forward(self, input: Tensor) -> Tensor:
+            weight, bias = self.cast_bias_weight(input)
+            return nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
diff --git a/invokeai/backend/quantization/gguf/utils.py b/invokeai/backend/quantization/gguf/utils.py
new file mode 100644
index 0000000000..2ed8d1ca00
--- /dev/null
+++ b/invokeai/backend/quantization/gguf/utils.py
@@ -0,0 +1,361 @@
+# Largely based on https://github.com/city96/ComfyUI-GGUF
+
+from typing import Callable, Optional, Union
+
+import gguf
+import torch
+
+TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
+
+# K Quants #
+QK_K = 256
+K_SCALE_SIZE = 12
+
+MODEL_DETECTION = (
+    (
+        "flux",
+        (
+            ("transformer_blocks.0.attn.norm_added_k.weight",),
+            ("double_blocks.0.img_attn.proj.weight",),
+        ),
+    ),
+)
+
+
+def get_scale_min(scales: torch.Tensor):
+    n_blocks = scales.shape[0]
+    scales = scales.view(torch.uint8)
+    scales = scales.reshape((n_blocks, 3, 4))
+
+    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
+
+    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
+    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
+
+    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
+
+
+# Legacy Quants #
+def dequantize_blocks_Q8_0(
+    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    d, x = split_block_dims(blocks, 2)
+    d = d.view(torch.float16).to(dtype)
+    x = x.view(torch.int8)
+    return d * x
+
+
+def dequantize_blocks_Q5_1(
+    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    n_blocks = blocks.shape[0]
+
+    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
+    d = d.view(torch.float16).to(dtype)
+    m = m.view(torch.float16).to(dtype)
+    qh = to_uint32(qh)
+
+    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+        [0, 4], device=d.device, dtype=torch.uint8
+    ).reshape(1, 1, 2, 1)
+    qh = (qh & 1).to(torch.uint8)
+    ql = (ql & 0x0F).reshape((n_blocks, -1))
+
+    qs = ql | (qh << 4)
+    return
(d * qs) + m + + +def dequantize_blocks_Q5_0( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + d, qh, qs = split_block_dims(blocks, 2, 4) + d = d.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return d * qs + + +def dequantize_blocks_Q4_1( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + d, m, qs = split_block_dims(blocks, 2, 2) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qs = (qs & 0x0F).reshape(n_blocks, -1) + + return (d * qs) + m + + +def dequantize_blocks_Q4_0( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return d * qs + + +def dequantize_blocks_BF16( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) + + +def dequantize_blocks_Q6_K( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + ( + ql, + qh, + scales, + d, + ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16) + + scales = scales.view(torch.int8).to(dtype) + d = d.view(torch.float16).to(dtype) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = (qh & 0x03).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)).to(torch.int8) - 32 + q = q.reshape((n_blocks, QK_K // 16, -1)) + + return (d * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q5_K( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8) + + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, 
-1, 32)) + qh = (qh & 0x01).reshape((n_blocks, -1, 32)) + q = ql | (qh << 4) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q4_K( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q3_K( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12) + d = d.view(torch.float16).to(dtype) + + lscales, hscales = scales[:, :8], scales[:, 8:] + lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 2, 1) + ) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor( + [0, 2, 4, 6], device=d.device, dtype=torch.uint8 + ).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) + scales = scales.to(torch.int8) - 32 + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 + q = ql.to(torch.int8) - (qh << 2).to(torch.int8) + + return (dl * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q2_K( + blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + n_blocks = blocks.shape[0] + + scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) + + shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + + qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3 + qs = qs.reshape((n_blocks, QK_K // 16, 16)) + qs = dl * qs - ml + + return qs.reshape((n_blocks, -1)) + + +DEQUANTIZE_FUNCTIONS: dict[ + gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor] +] = { + gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, + gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, + gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, + gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1, + gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K, + 
+    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
+    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
+}
+
+
+def is_torch_compatible(tensor: Optional[torch.Tensor]):
+    return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES
+
+
+def is_quantized(tensor: torch.Tensor):
+    return not is_torch_compatible(tensor)
+
+
+def dequantize_tensor(
+    tensor: torch.Tensor, dtype: torch.dtype, dequant_dtype: Union[torch.dtype, str, None] = None
+) -> torch.Tensor:
+    qtype: Optional[gguf.GGMLQuantizationType] = getattr(tensor, "tensor_type", None)
+    oshape: torch.Size = getattr(tensor, "tensor_shape", tensor.shape)
+    if qtype is None:
+        raise ValueError("This is not a valid quantized tensor")
+    if qtype in TORCH_COMPATIBLE_QTYPES:
+        return tensor.to(dtype)
+    elif qtype in DEQUANTIZE_FUNCTIONS:
+        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
+        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
+    else:
+        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
+        return torch.from_numpy(new).to(tensor.device, dtype=dtype)
+
+
+def dequantize(
+    data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
+):
+    """
+    Dequantize tensor back to usable shape/dtype
+    """
+    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+    dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]
+
+    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
+
+    n_blocks = rows.numel() // type_size
+    blocks = rows.reshape((n_blocks, type_size))
+    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
+    return blocks.reshape(oshape)
+
+
+def to_uint32(x: torch.Tensor) -> torch.Tensor:
+    x = x.view(torch.uint8).to(torch.int32)
+    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
+
+
+def split_block_dims(blocks: torch.Tensor, *args):
+    n_max = blocks.shape[1]
+    dims = list(args) + [n_max - sum(args)]
+    return torch.split(blocks, dims, dim=1)
+
+
+PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
+
+
+def move_patch_to_device(item: PATCH_TYPES, device: torch.device) -> PATCH_TYPES:
+    if isinstance(item, torch.Tensor):
+        return item.to(device, non_blocking=True)
+    elif isinstance(item, tuple):
+        if len(item) == 0:
+            return item
+        if not isinstance(item[0], torch.Tensor):
+            raise ValueError("Invalid item")
+        return tuple(move_patch_to_device(x, device) for x in item)
+    elif isinstance(item, list):
+        if len(item) == 0:
+            return item
+        if not isinstance(item[0], torch.Tensor):
+            raise ValueError("Invalid item")
+        return [move_patch_to_device(x, device) for x in item]
+
+
+def detect_arch(state_dict: dict[str, torch.Tensor]):
+    for arch, match_lists in MODEL_DETECTION:
+        for match_list in match_lists:
+            if all(key in state_dict for key in match_list):
+                return arch
+    raise ValueError("Unknown model architecture!")
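
Reviewer note (not part of the diff): a minimal sketch of how the new pieces are intended to compose, mirroring the FluxGGUFCheckpointModel loader above. The model path, the Flux/params imports, and the params key are assumptions for illustration only; the real loader resolves them through MainGGUFCheckpointConfig and the existing flux.py machinery.

from pathlib import Path

from invokeai.backend.flux.model import Flux  # assumed import path for the Flux module
from invokeai.backend.flux.util import params  # assumed FluxParams lookup used by the existing loaders
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher

gguf_path = Path("/models/flux1-dev-Q4_K_S.gguf")  # hypothetical .gguf file

# Read the quantized state dict. Tensors keep their GGML quantization and are wrapped
# in GGUFTensor so the logical shape and quantization type survive until forward time.
sd = gguf_sd_loader(gguf_path)

# Build the model while torch.nn layers are patched with the GGUF-aware subclasses,
# then attach the quantized tensors directly with assign=True; weights are
# dequantized on the fly in each patched forward() via cast_bias_weight().
with GGUFPatcher().wrap():
    model = Flux(params["flux-dev"])  # assumed key; the loader itself uses params[config.config_path]
    model.load_state_dict(sd, assign=True)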