Fix WebGPU atomic store (#7954)

This commit is contained in:
Ahmed Harmouche
2024-11-29 12:31:25 +01:00
committed by GitHub
parent 309dcb1044
commit 2d11765295
3 changed files with 13 additions and 9 deletions

View File

@@ -139,7 +139,7 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_equal(X.numpy(), 0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_index_mnist(self, noopt=1, op_limit=512*784*10):
def test_index_mnist(self, noopt=1, op_limit=512*784*13):
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):

View File

@@ -153,7 +153,6 @@ class TestJit(unittest.TestCase):
for _ in range(5): add(a)
self.assertEqual(a.item(), 5)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "TODO: fix this bug in WebGPU")
def test_jit_assign_int8(self): self.test_jit_assign(dtypes.int8)
def test_kwargs_jit(self):

View File

@@ -6,7 +6,7 @@ from tinygrad.helpers import strip_parens
import math
# utility functions for handling packed load/store of < 4-byte data types: bool, char/uchar, short/ushort
packed_types = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32}
unpack_map = {dtypes.bool: dtypes.int, dtypes.char: dtypes.int, dtypes.uchar: dtypes.uint32, dtypes.short: dtypes.int, dtypes.ushort: dtypes.uint32}
def sign_extend(val:UOp, sext_am:int):
return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
@@ -14,16 +14,19 @@ def sign_extend(val:UOp, sext_am:int):
# store for char: buf[idx/4] <- (var << (idx%4)*8))
def packed_store(bidx:UOp, var:UOp):
unpacked_type = unpack_map[var.dtype]
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize)
new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), new_v.cast(packed_types[var.dtype]))
mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(unpacked_type)
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), dtype=unpacked_type)
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), ((buf & mask) | new_v.cast(unpacked_type)))
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
div_idx = bidx.src[1]//(4//dtype.itemsize)
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=packed_types[dtype], arg=root.arg)
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=packed_types[dtype], arg=root.arg)
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=unpack_map[dtype], arg=root.arg)
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=unpack_map[dtype], arg=root.arg)
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
@@ -36,7 +39,7 @@ wgsl_matcher = PatternMatcher([
lambda a,b: UOp(a, dtypes.float, (b.cast(dtypes.float),)).cast(b.dtype)) for a in (Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.SQRT)],
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(packed_types[l.dtype])) if l.dtype.itemsize < 4 else None),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"),
UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))),
@@ -73,8 +76,10 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v")), allow_any_len=True),\
lambda ctx,b,v: f"atomicAdd(&{ctx[b]}, {ctx[v]});" if b.src[0].dtype.itemsize < 4 else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
else f"{ctx[b]} = {ctx[v]};"),
# fix nan check: 'a != a -> is_nan()'
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
]) + base_rewrite