rangeify: fix test_setitem (#12269)

* rangeify: fix test_setitem

* um?

* better?

* simple where folding

* f

* revert

* x
This commit is contained in:
nimlgen
2025-09-23 20:42:36 +03:00
committed by GitHub
parent 2f145a98e0
commit 02a7b7fe48
4 changed files with 9 additions and 7 deletions

View File

@@ -525,7 +525,8 @@ jobs:
CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
-k "not test_load_state_dict_sharded_model_dict_same_axis and not test_instancenorm_3d" \
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_tensor_variable.py \
test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py
test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py \
test/test_setitem.py
- name: Test const folding
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
- name: Test multitensor

View File

@@ -2108,7 +2108,6 @@ class TestView(unittest.TestCase):
# a*VIEW(x), where VIEW(x) = 0
# x+2
# as long as one child realizes, x does not collapse
@expect_rangeify_fails
def test_parent_multiple_children_no_collapse(self):
a = Tensor([1, 2])
b = Tensor.arange(3).contiguous()

View File

@@ -421,7 +421,7 @@ def bufferize_to_store(x:UOp):
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
if x.src[0].op is Ops.ASSIGN:
assign_target, assign_src, assign_mops = x.src[0].src
assert assign_target.op is Ops.INDEX
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
# in assign, this is the buffer size, not the bufferize size
# TODO: assign_mops here
ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype)

View File

@@ -113,7 +113,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# new decomp rules for threefry
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
(UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0))
(UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0)),
# ** simple where folding **
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
])
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
@@ -292,9 +296,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** where folding **
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)
if f.arg is not Invalid else None),
# alu of two where with same conds can combine, only do if true branch or false branch is const