rangeify: INDEX doesn't passthrough MSELECT (#12279)

This commit is contained in:
qazal
2025-09-23 21:36:50 +03:00
committed by GitHub
parent 02a7b7fe48
commit ad7c8c21ea
2 changed files with 2 additions and 2 deletions

View File

@@ -531,7 +531,7 @@ jobs:
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
- name: Test multitensor
run: |
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W TestMultiTensor.test_simple_reduce
CPU=1 RANGEIFY=1 python3 -m pytest test/test_multitensor.py::TestMultiAssign -k 'not (multi_assign_piece_noncontig or multi_assign_var_offset)'
- name: Test CPU=1 RANGEIFY=2
run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20

View File

@@ -336,7 +336,7 @@ pm_rangeify = pm_mops+PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
# assert if there's any index we didn't process
(UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE}).f(Ops.INDEX, name="x"), unprocessed_index),
(UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE, Ops.MSELECT}).f(Ops.INDEX, name="x"), unprocessed_index),
])
# *****************