fix wgsl bitcast (#14600)

was wrong for signed int
This commit is contained in:
chenyu
2026-02-06 16:57:36 -05:00
committed by GitHub
parent b9fe8b7591
commit 7d193a6e26
2 changed files with 6 additions and 5 deletions

View File

@@ -40,8 +40,6 @@ def ggml_tensor_to_numpy(tensor: ggml.ggml_tensor_p):
return np.lib.stride_tricks.as_strided(output, shape=shape, strides=strides), ctx
@unittest.skipIf(any(not is_dtype_supported(t) for t in [ dtypes.uint8, dtypes.half ]), "Backend must support uint8 and half")
# TODO: WEBGPU GGUF dequantization produces incorrect values
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU GGUF dequantization issue")
class TestGGUF(unittest.TestCase):
def setUp(self) -> None:
params = ggml.ggml_init_params(mem_size=0, mem_buffer=None, no_alloc=False)

View File

@@ -69,9 +69,12 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
(UPat(Ops.BITCAST, dtype=dtypes.char, name="x"), lambda ctx,x: f"((i32({ctx[x.src[0]]}&0xFF)<<24)>>24)"),
(UPat(Ops.BITCAST, dtype=dtypes.ushort, name="x"), lambda ctx,x: f"bitcast<u32>(vec2<f16>({ctx[x.src[0]]},0))" \
if x.src[0].dtype == dtypes.half else f"bitcast<u32>({ctx[x.src[0]]}&0xFFFF)"),
(UPat(Ops.BITCAST, dtype=dtypes.short, name="x"), lambda ctx,x: f"bitcast<i32>(vec2<f16>({ctx[x.src[0]]},0))" \
if x.src[0].dtype == dtypes.half else f"((i32({ctx[x.src[0]]}&0xFFFF)<<16)>>16)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), allow_any_len=True),