def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
  """
  Decodes raw ggml tensor data into a tinygrad tensor.

  `t` holds the raw data, `n` is the number of elements to decode and `ggml_type` is the ggml type id.

  Supported native types: float32 (id: 0), float16 (id: 1), int8 (id: 16), int16 (id: 17), int32 (id: 18)
  Supported quantized types: Q4_0 (id: 2), Q4_1 (id: 3), Q8_0 (id: 8), Q6_K (id: 14)

  Raises `ValueError` for any other type id.
  """
  # type ids: https://github.com/ggerganov/ggml/blob/6dccc647264f5429df2624f36138f601e7ce23e5/include/ggml.h#L356
  # NOTE(review): the byte offsets below assume `t` is a flat byte (uint8) tensor -- confirm against the caller

  # native types: take the first n * itemsize entries of `t` and reinterpret them
  native_dtypes = { 0: dtypes.float32, 1: dtypes.float16, 16: dtypes.int8, 17: dtypes.int16, 18: dtypes.int32 }
  if ggml_type in native_dtypes:
    native = native_dtypes[ggml_type]
    return t[:native.itemsize * n].bitcast(native)

  # quantized types: ggml type id -> (number of elements per block, number of bytes per block)
  block_layouts = { 2: (32, 18), 3: (32, 20), 8: (32, 34), 14: (256, 210) }
  if ggml_type not in block_layouts:
    raise ValueError(f"GGML type '{ggml_type}' is not supported!")

  def unpack_bits(packed: Tensor, bits: int) -> Tensor:
    # splits every entry of `packed` into 8//bits values of `bits` bits each, lowest bits first
    # TODO: rewrite with arange?
    per_entry = 8 // bits
    divisors = Tensor.stack(*[ Tensor(2**(k*bits), device=packed.device, dtype=packed.dtype) for k in range(per_entry) ])
    mask = 0xff >> (8 - bits)
    spread = packed.unsqueeze(-1).expand((*packed.shape, per_entry))
    # dividing by 2**(k*bits) without upcast acts as a right shift; the mask keeps the low `bits` bits
    return spread.div(divisors, upcast=False).bitwise_and(mask).transpose(-1, -2).flatten(-2)

  def read_f16(raw: Tensor) -> Tensor:
    # reinterprets the given columns as float16 and widens the result to float32
    return raw.bitcast(dtypes.float16).cast(dtypes.float32)

  block_elems, block_bytes = block_layouts[ggml_type]
  # one row per block; only whole blocks (n // block_elems of them) are decoded
  blocks = t[:(n // block_elems) * block_bytes].reshape((-1, block_bytes))

  if ggml_type == 2:
    # Q4_0: 2 byte float16 scale, then packed 4-bit values stored with an offset of 8
    return (unpack_bits(blocks[:,2:], 4).bitcast(dtypes.int8) - 8) * read_f16(blocks[:,:2])

  if ggml_type == 3:
    # Q4_1: float16 scale in bytes 0-1, float16 offset in bytes 2-3, then packed 4-bit values
    d, m = read_f16(blocks[:,0:2]), read_f16(blocks[:,2:4])
    return unpack_bits(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m

  if ggml_type == 8:
    # Q8_0: 2 byte float16 scale, then one int8 per element
    return read_f16(blocks[:,:2]) * blocks[:,2:].bitcast(dtypes.int8)

  # Q6_K (id: 14) is the only layout left:
  # 128 bytes of low 4-bit parts, 64 bytes of high 2-bit parts, 16 int8 scales, 2 byte float16 scale
  low4 = unpack_bits(blocks[:,:128].reshape((-1, 2, 64)), 4)
  high2 = unpack_bits(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)
  # each of the 16 scales is repeated for 16 consecutive elements of the 256 element block
  scales = blocks[:,192:208].bitcast(dtypes.int8).unsqueeze(-1).expand((-1, 16, 16)).reshape((-1, 256))
  d = read_f16(blocks[:,-2:]).expand((-1, 256))
  return d * (low4.bitwise_or(high2).bitcast(dtypes.int8) - 32).flatten(-2) * scales