mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
[webgpu] Fix atomic shared mem load inside loop (#10530)
* Disable shared mem atomics on webgpu * allow_any_len in load pattern matcher to fix temp load inside loop
This commit is contained in:
@@ -189,8 +189,6 @@ class TestIndexing(unittest.TestCase):
|
||||
np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
|
||||
np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
|
||||
|
||||
# TODO: fix these on WEBGPU, it looks like it has to do with packed stuff
|
||||
@unittest.skipIf(getenv("WEBGPU"), "broken on webgpu for some reason")
|
||||
def test_index_mnist_opt(self): self.test_index_mnist(0)
|
||||
def test_index_mnist_split(self): self.test_index_mnist(1, split_reduceop=1)
|
||||
def test_index_mnist_opt_split(self): self.test_index_mnist(0, split_reduceop=1)
|
||||
|
||||
@@ -30,9 +30,8 @@ def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.ha
|
||||
wgsl_matcher = PatternMatcher([
|
||||
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
|
||||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"),)), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"))),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), name="l"),lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
|
||||
# TODO: why is this needed, and only for this MUL order
|
||||
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
|
||||
|
||||
Reference in New Issue
Block a user