Fix setitem on wgpu (#8144)

This commit is contained in:
Ahmed Harmouche
2024-12-10 19:34:25 +01:00
committed by GitHub
parent b69fea6ae5
commit 71dd222f66
2 changed files with 2 additions and 2 deletions

View File

@@ -380,7 +380,7 @@ jobs:
WEBGPU=1 WGPU_BACKEND_TYPE=Vulkan python3 -m pytest -n=auto test/test_assign.py test/test_arange.py test/test_const_folding.py test/test_dtype.py \
test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20
test/test_setitem.py test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20
- name: Run process replay tests
run: |
export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")

View File

@@ -72,7 +72,7 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v")), allow_any_len=True),lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
else f"{ctx[b]} = {ctx[v]};"),