mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
WEBGPU/NIR truncates ints (#14307)
* WEBGPU truncates ints * nir has this bug too
This commit is contained in:
committed by
GitHub
parent
26220a472e
commit
e782d44918
@@ -3281,6 +3281,10 @@ class TestOps(unittest.TestCase):
|
||||
def test_bitcast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
||||
|
||||
def test_int_or(self):
|
||||
t = (Tensor([0], dtype='int') | 0xFFFFFFFF).item()
|
||||
if not COMPILE_ONLY: assert t == -1
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Callable, cast, Any
|
||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes
|
||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
|
||||
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
@@ -70,7 +70,7 @@ def nchannel(b:mesa.nir_builder, src:mesa.nir_def, c:int):
|
||||
|
||||
def nimm_set(imm:mesa.nir_def, x, dtype:DType):
|
||||
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
|
||||
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, x)
|
||||
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x))
|
||||
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes, AddrSpace
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes, truncate, AddrSpace
|
||||
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
|
||||
from tinygrad.helpers import strip_parens
|
||||
@@ -64,6 +64,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"),
|
||||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
|
||||
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
|
||||
|
||||
Reference in New Issue
Block a user