mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
rangeify: fix test_setitem (#12269)
* rangeify: fix test_setitem * um? * better? * simple where folding * f * revert * x
This commit is contained in:
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@@ -525,7 +525,8 @@ jobs:
|
||||
CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
|
||||
-k "not test_load_state_dict_sharded_model_dict_same_axis and not test_instancenorm_3d" \
|
||||
test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_tensor_variable.py \
|
||||
test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py
|
||||
test/test_outerworld_range.py test/test_randomness.py test/test_nn.py test/test_arange.py test/test_tensor.py test/test_optim.py \
|
||||
test/test_setitem.py
|
||||
- name: Test const folding
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
|
||||
- name: Test multitensor
|
||||
|
||||
@@ -2108,7 +2108,6 @@ class TestView(unittest.TestCase):
|
||||
# a*VIEW(x), where VIEW(x) = 0
|
||||
# x+2
|
||||
# as long as one child realizes, x does not collapse
|
||||
@expect_rangeify_fails
|
||||
def test_parent_multiple_children_no_collapse(self):
|
||||
a = Tensor([1, 2])
|
||||
b = Tensor.arange(3).contiguous()
|
||||
|
||||
@@ -421,7 +421,7 @@ def bufferize_to_store(x:UOp):
|
||||
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
|
||||
if x.src[0].op is Ops.ASSIGN:
|
||||
assign_target, assign_src, assign_mops = x.src[0].src
|
||||
assert assign_target.op is Ops.INDEX
|
||||
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
|
||||
# in assign, this is the buffer size, not the bufferize size
|
||||
# TODO: assign_mops here
|
||||
ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype)
|
||||
|
||||
@@ -113,7 +113,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
# new decomp rules for threefry
|
||||
(((UPat.var(None, dtypes.uint64)<<32) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||
(((UPat.var('x', dtypes.uint64)<<32) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))>>32, lambda x: x),
|
||||
(UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0))
|
||||
(UPat.var('b').where(UPat.var('x', dtypes.uint32).cast(dtypes.uint64), UPat.const(dtypes.uint64, 0)).cast(dtypes.uint32), lambda b,x: b.where(x,0)),
|
||||
# ** simple where folding **
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
])
|
||||
|
||||
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
||||
@@ -292,9 +296,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
|
||||
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** where folding **
|
||||
(UPat.var("cond", dtype=dtypes.bool).logical_not().where(UPat.var("t"), UPat.var("f")), lambda cond, t, f: cond.where(f,t)
|
||||
if f.arg is not Invalid else None),
|
||||
# alu of two where with same conds can combine, only do if true branch or false branch is const
|
||||
|
||||
Reference in New Issue
Block a user