diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8373f33a82..aba1ed3b3a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -525,7 +525,8 @@ jobs: CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \ -k "not test_load_state_dict_sharded_model_dict_same_axis and not test_instancenorm_3d" \ test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_tensor_variable.py \ - test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py + test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py \ + test/test_setitem.py - name: Test const folding run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding" - name: Test multitensor diff --git a/test/test_schedule.py b/test/test_schedule.py index 2c0d8bd197..42facd5ee9 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -2108,7 +2108,6 @@ class TestView(unittest.TestCase): # a*VIEW(x), where VIEW(x) = 0 # x+2 # as long as one child realizes, x does not collapse - @expect_rangeify_fails def test_parent_multiple_children_no_collapse(self): a = Tensor([1, 2]) b = Tensor.arange(3).contiguous() diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index af496cfdb8..24efcd13c1 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -421,7 +421,7 @@ def bufferize_to_store(x:UOp): sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace) if x.src[0].op is Ops.ASSIGN: assign_target, assign_src, assign_mops = x.src[0].src - assert assign_target.op is Ops.INDEX + assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index" # in assign, this is the buffer size, not the bufferize size # TODO: assign_mops here ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index bd35de5bc2..bbb9458b32 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -113,7 +113,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([ # new decomp rules for threefry (((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), (((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x), - (UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0)) + (UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0)), + # ** simple where folding ** + # a conditional with the same results either way is a noop, also fold const conditionals + (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), + (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), ]) # ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ******** @@ -292,9 +296,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2), ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3) (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c - # a conditional with the same results either way is a noop, also fold const conditionals - (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), - (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), + # ** where folding ** (UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t) if f.arg is not Invalid else None), # alu of two where with same conds can combine, only do if true branch or false branch is const