From 7d193a6e260ddcb172d99bac3b06dbf86498eb3f Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 6 Feb 2026 16:57:36 -0500 Subject: [PATCH] fix wgsl bitcast (#14600) was wrong for signed int --- test/unit/test_gguf.py | 2 -- tinygrad/renderer/wgsl.py | 9 ++++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/unit/test_gguf.py b/test/unit/test_gguf.py index da9a7fe670..4cdff86581 100644 --- a/test/unit/test_gguf.py +++ b/test/unit/test_gguf.py @@ -40,8 +40,6 @@ def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p): return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx @unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half") -# TODO: WEBGPU GGUF dequantization produces incorrect values -@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GGUF dequantization issue") class TestGGUF(unittest.TestCase): def setUp(self) -> None: params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False) diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 51b8c8e24a..58da950557 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -69,9 +69,12 @@ class WGSLRenderer(CStyleLanguage): (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"), (UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"), - (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"), - (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \ - if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"), + (UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: 
f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"), + (UPat(Ops.BITCAST, dtype=dtypes.char, name="x"), lambda ctx,x: f"((i32({ctx[x.src[0]]}&0xFF)<<24)>>24)"), + (UPat(Ops.BITCAST, dtype=dtypes.ushort, name="x"), lambda ctx,x: f"bitcast<u32>(vec2<f16>({ctx[x.src[0]]},0))" \ + if x.src[0].dtype == dtypes.half else f"bitcast<u32>({ctx[x.src[0]]}&0xFFFF)"), + (UPat(Ops.BITCAST, dtype=dtypes.short, name="x"), lambda ctx,x: f"bitcast<i32>(vec2<f16>({ctx[x.src[0]]},0))" \ + if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"), # TODO: load alt value doesnt have to be a const (UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),