diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 3439e313b0..093957a37a 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -319,25 +319,16 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor: d, m = (blocks[:,s:s+2].bitcast(dtypes.float16).cast(dtypes.float32) for s in [ 0, 2 ]) return q_to_uint8(blocks[:,4:], 4).bitcast(dtypes.int8) * d + m if ggml_type == 8: return blocks[:,:2].bitcast(dtypes.float16).cast(dtypes.float32) * blocks[:,2:].bitcast(dtypes.int8) - if ggml_type == 12: # Q4_K: 256 elements per 144-byte block (d:2, dmin:2, scales:12, qs:128) + # Q4_K: 256 elements per 144-byte block (d:2, dmin:2, scales:12, qs:128) + # Q5_K: 256 elements per 176-byte block (d:2, dmin:2, scales:12, qh:32, qs:128) + if ggml_type in (12, 13): d, dmin = (blocks[:,i:i+2].bitcast(dtypes.float16).cast(dtypes.float32).unsqueeze(-1) for i in [0, 2]) s = blocks[:,4:16] # 12 bytes: 6-bit scales[0-3], 6-bit mins[0-3], high bits[4-7] sc = s[:,0:4].bitwise_and(63).cat(s[:,8:12].bitwise_and(0xF).bitwise_or(s[:,0:4].rshift(6).lshift(4)), dim=-1) mn = s[:,4:8].bitwise_and(63).cat(s[:,8:12].rshift(4).bitwise_or(s[:,4:8].rshift(6).lshift(4)), dim=-1) - q = Tensor.stack((qs:=blocks[:,16:144].reshape(-1,4,32)).bitwise_and(0xF), qs.rshift(4), dim=2).reshape(-1,8,32).cast(dtypes.float32) - return (d * sc.unsqueeze(-1) * q - dmin * mn.unsqueeze(-1)).flatten(-2) - if ggml_type == 13: # Q5_K: 256 elements per 176-byte block (d:2, dmin:2, scales:12, qh:32, qs:128) - d, dmin = (blocks[:,i:i+2].bitcast(dtypes.float16).cast(dtypes.float32).unsqueeze(-1) for i in [0, 2]) - s = blocks[:,4:16] # 12 bytes: same scale packing as Q4_K - sc = s[:,0:4].bitwise_and(63).cat(s[:,8:12].bitwise_and(0xF).bitwise_or(s[:,0:4].rshift(6).lshift(4)), dim=-1) - mn = s[:,4:8].bitwise_and(63).cat(s[:,8:12].rshift(4).bitwise_or(s[:,4:8].rshift(6).lshift(4)), dim=-1) - qh = blocks[:,16:48].reshape(-1, 1, 32) # (nblocks, 1, 32) high bits - ql = blocks[:,48:176].reshape(-1, 4, 32) # (nblocks, 4, 32) low nibbles - # for each of 4 groups of 64: low = ql & 0xF, high = ql >> 4; the 5th bit comes from qh with shifting masks - u = Tensor([1, 4, 16, 64], device=blocks.device, dtype=dtypes.uint8).reshape(1, 4, 1) # bit masks for qh: 1,4,16,64 = bits 0,2,4,6 - qlo = Tensor.stack(ql.bitwise_and(0xF), ql.rshift(4), dim=2).reshape(-1, 8, 32) # (nblocks, 8, 32) 4-bit values - qhi = Tensor.stack(qh.bitwise_and(u), qh.bitwise_and(u.lshift(1)), dim=2).reshape(-1, 8, 32) # high bits - q = (qlo + (qhi != 0).where(16, 0)).cast(dtypes.float32) + qs_off = 48 if ggml_type == 13 else 16 + q = Tensor.stack((qs:=blocks[:,qs_off:qs_off+128].reshape(-1,4,32)).bitwise_and(0xF), qs.rshift(4), dim=2).reshape(-1,8,32) + if ggml_type == 13: q = q + q_to_uint8(blocks[:,16:48], 1).reshape(-1, 8, 32) * 16 return (d * sc.unsqueeze(-1) * q - dmin * mn.unsqueeze(-1)).flatten(-2) if ggml_type == 14: xl, xh = q_to_uint8(blocks[:,:128].reshape((-1, 2, 64)), 4), q_to_uint8(blocks[:,128:192].reshape((-1, 2, 32)), 2).lshift(4)