# Largely based on https://github.com/city96/ComfyUI-GGUF
from typing import Callable, Optional, Union

import gguf
import torch

# should not be a Set until this is resolved: https://github.com/pytorch/pytorch/issues/145761
TORCH_COMPATIBLE_QTYPES = [None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]

# K Quants #
QK_K = 256  # elements per K-quant super-block
K_SCALE_SIZE = 12  # bytes of packed 6-bit scales/mins per super-block


# Unpack the 12 bytes of packed 6-bit scales and mins shared by Q4_K and Q5_K.
def get_scale_min(scales: torch.Tensor):
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))


# Legacy Quants #
# Q8_0: fp16 scale `d` followed by 32 int8 quants; value = d * q.
def dequantize_blocks_Q8_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return d * x


# Q5_1: fp16 scale `d` and min `m`, 32 high bits packed in `qh`, low nibbles in `qs`; value = d * q + m.
def dequantize_blocks_Q5_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qh = to_uint32(qh)
    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = ql | (qh << 4)
    return (d * qs) + m


# Q5_0: fp16 scale `d`, 32 high bits packed in `qh`, low nibbles in `qs`; value = d * (q - 16).
def dequantize_blocks_Q5_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)

    qh = to_uint32(qh)
    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs


# Q4_1: fp16 scale `d` and min `m`, 4-bit quants packed two per byte; value = d * q + m.
def dequantize_blocks_Q4_1(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


# Q4_0: fp16 scale `d`, 4-bit quants packed two per byte; value = d * (q - 8).
def dequantize_blocks_Q4_0(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8
    ).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return d * qs


# BF16: widen each 16-bit value into the high half of a float32.
def dequantize_blocks_BF16(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


def dequantize_blocks_Q6_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    (
        ql,
        qh,
        scales,
        d,
    ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 2, 1)
    )
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
        [0, 2, 4, 6], device=d.device, dtype=torch.uint8
    ).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = scales.to(torch.int8) - 32

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 4, 1)
    )
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = ql.to(torch.int8) - (qh << 2).to(torch.int8)

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


DEQUANTIZE_FUNCTIONS: dict[
    gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor]
] = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


def is_torch_compatible(tensor: Optional[torch.Tensor]):
    return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor: torch.Tensor):
    return not is_torch_compatible(tensor)


def dequantize(
    data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None
):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype]

    rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def to_uint32(x: torch.Tensor) -> torch.Tensor:
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks: torch.Tensor, *args):
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
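

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch, not part of the original module.
# The "model.gguf" path is an assumed placeholder; the reader API
# (gguf.GGUFReader, ReaderTensor.tensor_type / .shape / .data) comes from the
# gguf package. Only qtypes present in DEQUANTIZE_FUNCTIONS are dequantized,
# while F16/F32 tensors are used directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reader = gguf.GGUFReader("model.gguf")  # hypothetical path, replace as needed
    for tensor in reader.tensors:
        qtype = tensor.tensor_type
        # GGUF stores dimensions in GGML order; reverse them for torch.
        oshape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        data = torch.from_numpy(tensor.data)
        if qtype in TORCH_COMPATIBLE_QTYPES:
            weight = data.reshape(oshape)
        elif qtype in DEQUANTIZE_FUNCTIONS:
            weight = dequantize(data, qtype, oshape, dtype=torch.float32)
        else:
            continue  # quantization type not handled by this module
        print(tensor.name, qtype.name, tuple(weight.shape), weight.dtype)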