mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Fix WebGPU atomic store (#7954)
This commit is contained in:
@@ -139,7 +139,7 @@ class TestIndexing(unittest.TestCase):
|
||||
np.testing.assert_equal(X.numpy(), 0)
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
||||
def test_index_mnist(self, noopt=1, op_limit=512*784*10):
|
||||
def test_index_mnist(self, noopt=1, op_limit=512*784*13):
|
||||
from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train, _, _ = mnist()
|
||||
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
|
||||
|
||||
@@ -153,7 +153,6 @@ class TestJit(unittest.TestCase):
|
||||
for _ in range(5): add(a)
|
||||
self.assertEqual(a.item(), 5)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "TODO: fix this bug in WebGPU")
|
||||
def test_jit_assign_int8(self): self.test_jit_assign(dtypes.int8)
|
||||
|
||||
def test_kwargs_jit(self):
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.helpers import strip_parens
|
||||
import math
|
||||
|
||||
# utility functions for handling packed load/store of < 4-byte data types: bool, char/uchar, short/ushort
|
||||
packed_types = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32}
|
||||
unpack_map = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32}
|
||||
|
||||
def sign_extend(val:UOp, sext_am:int):
|
||||
return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
|
||||
@@ -14,16 +14,19 @@ def sign_extend(val:UOp, sext_am:int):
|
||||
|
||||
# store for char: buf[idx/4] <- (var << (idx%4)*8))
|
||||
def packed_store(bidx:UOp, var:UOp):
|
||||
unpacked_type = unpack_map[var.dtype]
|
||||
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize)
|
||||
new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am
|
||||
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), new_v.cast(packed_types[var.dtype]))
|
||||
mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(unpacked_type)
|
||||
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), dtype=unpacked_type)
|
||||
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), ((buf & mask) | new_v.cast(unpacked_type)))
|
||||
|
||||
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
|
||||
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
|
||||
div_idx = bidx.src[1]//(4//dtype.itemsize)
|
||||
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
|
||||
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=packed_types[dtype], arg=root.arg)
|
||||
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=packed_types[dtype], arg=root.arg)
|
||||
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=unpack_map[dtype], arg=root.arg)
|
||||
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=unpack_map[dtype], arg=root.arg)
|
||||
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
@@ -36,7 +39,7 @@ wgsl_matcher = PatternMatcher([
|
||||
lambda a,b: UOp(a, dtypes.float, (b.cast(dtypes.float),)).cast(b.dtype)) for a in (Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.SQRT)],
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(packed_types[l.dtype])) if l.dtype.itemsize < 4 else None),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
|
||||
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"),
|
||||
UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))),
|
||||
@@ -73,8 +76,10 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v")), allow_any_len=True),\
|
||||
lambda ctx,b,v: f"atomicAdd(&{ctx[b]}, {ctx[v]});" if b.src[0].dtype.itemsize < 4 else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
|
||||
]) + base_rewrite
|
||||
|
||||
Reference in New Issue
Block a user