WEBGPU/NIR truncates ints (#14307)

* WEBGPU truncates ints

* nir has this bug too
This commit is contained in:
Christopher Milan
2026-01-23 16:28:06 -08:00
committed by GitHub
parent 26220a472e
commit e782d44918
3 changed files with 8 additions and 3 deletions

View File

@@ -3281,6 +3281,10 @@ class TestOps(unittest.TestCase):
# Checks that bitcasting a (3, 3) input to int32 matches torch's `view(torch.int32)`.
# Only the forward pass is compared (forward_only=True).
def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
# Regression test: OR-ing an 'int' tensor holding 0 with the Python constant 0xFFFFFFFF
# must produce -1, i.e. the constant has to be wrapped into the signed integer range
# instead of being emitted as the out-of-range value 4294967295.
# NOTE(review): assumes dtype 'int' is a 32-bit signed integer on the backend under test.
def test_int_or(self):
t = (Tensor([0], dtype='int') | 0xFFFFFFFF).item()
# The value is only checked when COMPILE_ONLY is not set.
if not COMPILE_ONLY: assert t == -1
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
def test_cast(self):

View File

@@ -1,5 +1,5 @@
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -70,7 +70,7 @@ def nchannel(b:mesa.nir_builder, src:mesa.nir_def, c:int):
def nimm_set(imm:mesa.nir_def, x, dtype:DType):
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x)
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x))
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:

View File

@@ -1,4 +1,4 @@
from tinygrad.dtype import DType, PtrDType, dtypes, AddrSpace
from tinygrad.dtype import DType, PtrDType, dtypes, truncate, AddrSpace
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
from tinygrad.helpers import strip_parens
@@ -64,6 +64,7 @@ class WGSLRenderer(CStyleLanguage):
(UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"),
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),