mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -40,8 +40,6 @@ def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p):
|
||||
return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx
|
||||
|
||||
@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
|
||||
# TODO: WEBGPU GGUF dequantization produces incorrect values
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GGUF dequantization issue")
|
||||
class TestGGUF(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False)
|
||||
|
||||
@@ -69,9 +69,12 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
|
||||
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
|
||||
if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.char, name="x"), lambda ctx,x: f"((i32({ctx[x.src[0]]}&0xFF)<<24)>>24)"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.ushort, name="x"), lambda ctx,x: f"bitcast<u32>(vec2<f16>({ctx[x.src[0]]},0))" \
|
||||
if x.src[0].dtype == dtypes.half else f"bitcast<u32>({ctx[x.src[0]]}&0xFFFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.short, name="x"), lambda ctx,x: f"bitcast<i32>(vec2<f16>({ctx[x.src[0]]},0))" \
|
||||
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),
|
||||
|
||||
Reference in New Issue
Block a user