ggml_data_to_tensor touchups (#7196)

* ggml_data_to_tensor touchups

tiny reordering and variable name changes

* return type

* pylint
This commit is contained in:
chenyu
2024-10-21 13:29:59 -04:00
committed by GitHub
parent 815e1a340c
commit f93bd9e2b9

View File

@@ -227,33 +227,37 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
f.seek(rwd)
return TorchPickle(f).load()
def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int):
def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
"""
Converts ggml tensor data to a tinygrad tensor.
Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q6_K (id: 14), Q8_0 (id: 8)
Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)
"""
# https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
bc_dtype = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type, None)
if bc_dtype is not None: return t[:bc_dtype.itemsize * n].bitcast(bc_dtype)
# native types
if (dtype := { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }.get(ggml_type)) is not None:
return t[:dtype.itemsize * n].bitcast(dtype)
def q_to_uint8(t: Tensor, b: int) -> Tensor:
# TODO: rewrite with arange?
shift_tensor, bitmask = Tensor.stack(*[ Tensor(2**(i*b), device=t.device, dtype=t.dtype) for i in range(8//b) ]), 0xff >> (8 - b)
return t.unsqueeze(-1).expand((*t.shape,8//b)).div(shift_tensor, upcast=False).bitwise_and(bitmask).transpose(-1, -2).flatten(-2)
blk_info = { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type, None) # map to (number of elements, number of bytes)
blocks = t if blk_info is None else t[:(n//blk_info[0])*blk_info[1]].reshape((-1, blk_info[1]))
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
if ggml_type == 3:
d, m = tuple(blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
if ggml_type == 14:
xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((blocks.shape[0], 16, 16)).reshape((-1, 256))
d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
# map to (number of elements, number of bytes)
if (nelements_nbytes := { 2: (32, 18), 3: (32, 20), 14: (256, 210), 8: (32, 34) }.get(ggml_type)) is not None:
blocks = t[:(n//nelements_nbytes[0])*nelements_nbytes[1]].reshape((-1, nelements_nbytes[1]))
if ggml_type == 2: return (q_to_uint8(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32)
if ggml_type == 3:
d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ])
return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m
if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8)
if ggml_type == 14:
xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
d = blocks[:,-2:].bitcast(dtypes.float16).cast(dtypes.float32).expand((-1, 256))
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
raise ValueError(f"GGML type '{ggml_type}' is not supported!")
def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]: