mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
Fix setitem on wgpu (#8144)
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -380,7 +380,7 @@ jobs:
|
||||
WEBGPU=1 WGPU_BACKEND_TYPE=Vulkan python3 -m pytest -n=auto test/test_assign.py test/test_arange.py test/test_const_folding.py test/test_dtype.py \
|
||||
test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
|
||||
test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
|
||||
test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20
|
||||
test/test_setitem.py test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20
|
||||
- name: Run process replay tests
|
||||
run: |
|
||||
export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
|
||||
|
||||
@@ -72,7 +72,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
|
||||
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v")), allow_any_len=True),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
|
||||
Reference in New Issue
Block a user