[webgpu] Fix atomic shared mem load inside loop (#10530)

* Disable shared mem atomics on webgpu

* allow_any_len in load pattern matcher to fix temp load inside loop
This commit is contained in:
Ahmed Harmouche
2025-05-31 15:29:02 +02:00
committed by GitHub
parent 6af4b02374
commit 35eb4d357a
2 changed files with 2 additions and 5 deletions

View File

@@ -189,8 +189,6 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
# TODO: fix these on WEBGPU, it looks like it has to do with packed stuff
@unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
def test_index_mnist_opt(self): self.test_index_mnist(0)
def test_index_mnist_split(self): self.test_index_mnist(1, split_reduceop=1)
def test_index_mnist_opt_split(self): self.test_index_mnist(0, split_reduceop=1)

View File

@@ -30,9 +30,8 @@ def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.ha
wgsl_matcher = PatternMatcher([
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"),)), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"))),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
(UPat.load(UPat.var("b"), UPat.cvar("c"), name="l"),lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
# TODO: why is this needed, and only for this MUL order
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),