From e782d449180385e10264014e6be9f547bf4be14c Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Fri, 23 Jan 2026 16:28:06 -0800 Subject: [PATCH] WEBGPU/NIR truncates ints (#14307) * WEBGPU truncates ints * nir has this bug too --- test/test_ops.py | 4 ++++ tinygrad/renderer/nir.py | 4 ++-- tinygrad/renderer/wgsl.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index f52d9e6cf0..39d70e661d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3281,6 +3281,10 @@ class TestOps(unittest.TestCase): def test_bitcast(self): helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True) + def test_int_or(self): + t = (Tensor([0], dtype='int') | 0xFFFFFFFF).item() + if not COMPILE_ONLY: assert t == -1 + @unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}") class TestOpsUint8(unittest.TestCase): def test_cast(self): diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index b562f16743..3d53f72a1b 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -1,5 +1,5 @@ from typing import Callable, cast, Any -from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes +from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer @@ -70,7 +70,7 @@ def nchannel(b:mesa.nir_builder, src:mesa.nir_def, c:int): def nimm_set(imm:mesa.nir_def, x, dtype:DType): instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr)) - struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x) + struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x)) @nir_instr(nc=1, bs=lambda dtype: dtype.bitsize) def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def: diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 87a848660b..51b8c8e24a 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -1,4 +1,4 @@ -from tinygrad.dtype import DType, PtrDType, dtypes, AddrSpace +from tinygrad.dtype import DType, PtrDType, dtypes, truncate, AddrSpace from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm from tinygrad.helpers import strip_parens @@ -64,6 +64,7 @@ class WGSLRenderer(CStyleLanguage): (UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"), (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda x: f"bitcast({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"), + (UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"), (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"), (UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),