From 2d11765295e746735f0fc06bc320768270542dc6 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Fri, 29 Nov 2024 12:31:25 +0100 Subject: [PATCH] Fix WebGPU atomic store (#7954) --- test/test_arange.py | 2 +- test/test_jit.py | 1 - tinygrad/renderer/wgsl.py | 19 ++++++++++++------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/test/test_arange.py b/test/test_arange.py index e695672221..0cedfc1838 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -139,7 +139,7 @@ class TestIndexing(unittest.TestCase): np.testing.assert_equal(X.numpy(), 0) @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") - def test_index_mnist(self, noopt=1, op_limit=512*784*10): + def test_index_mnist(self, noopt=1, op_limit=512*784*13): from tinygrad.nn.datasets import mnist X_train, Y_train, _, _ = mnist() with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0): diff --git a/test/test_jit.py b/test/test_jit.py index e53e977bde..384688aad8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -153,7 +153,6 @@ class TestJit(unittest.TestCase): for _ in range(5): add(a) self.assertEqual(a.item(), 5) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "TODO: fix this bug in WebGPU") def test_jit_assign_int8(self): self.test_jit_assign(dtypes.int8) def test_kwargs_jit(self): diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index bddd1779d9..9433849cf0 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -6,7 +6,7 @@ from tinygrad.helpers import strip_parens import math # utility functions for handling packed load/store of < 4-byte data types: bool, char/uchar, short/ushort -packed_types = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32} +unpack_map = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32} def sign_extend(val:UOp, sext_am:int): return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \ @@ -14,16 +14,19 @@ def sign_extend(val:UOp, sext_am:int): # store for char: buf[idx/4] <- (var << (idx%4)*8)) def packed_store(bidx:UOp, var:UOp): + unpacked_type = unpack_map[var.dtype] shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize) new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am - return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), new_v.cast(packed_types[var.dtype])) + mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(unpacked_type) + buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), dtype=unpacked_type) + return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), ((buf & mask) | new_v.cast(unpacked_type))) # load for char: sign_extend(buf[idx/4] >> ((idx%4)*8)) def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None): div_idx = bidx.src[1]//(4//dtype.itemsize) shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize) - if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=packed_types[dtype], arg=root.arg) - else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=packed_types[dtype], arg=root.arg) + if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=unpack_map[dtype], arg=root.arg) + else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=unpack_map[dtype], arg=root.arg) val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF) return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype) @@ -36,7 +39,7 @@ wgsl_matcher = PatternMatcher([ lambda a,b: UOp(a, dtypes.float, (b.cast(dtypes.float),)).cast(b.dtype)) for a in (Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.SQRT)], (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None), (UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())), - lambda l,b,c: packed_load(l,b,l.dtype,c.cast(packed_types[l.dtype])) if l.dtype.itemsize < 4 else None), + lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None), (UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None), (UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"), UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))), @@ -73,8 +76,10 @@ class WGSLRenderer(CStyleLanguage): (UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"), - (UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v")), allow_any_len=True),\ - lambda ctx,b,v: f"atomicAdd(&{ctx[b]}, {ctx[v]});" if b.src[0].dtype.itemsize < 4 else f"{ctx[b]} = {ctx[v]};"), + (UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\ + # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1] + f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \ + else f"{ctx[b]} = {ctx[v]};"), # fix nan check: 'a != a -> is_nan()' (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"), ]) + base_rewrite