fix some assigns on rangeify (#11774)

* fix some assigns

* llvm test

* more tests

* upd test
Commit 5954a0975f (parent 2e0eb88549)
Author: George Hotz
Date: 2025-08-21 15:15:54 -07:00
Committed by: GitHub
3 changed files with 30 additions and 14 deletions


@@ -601,13 +601,22 @@ jobs:
       - name: Setup Environment
         uses: ./.github/actions/setup-tinygrad
         with:
-          key: rangeify-minimal
+          key: rangeify-minimal-llvm
           deps: testing_minimal
+          llvm: "true"
       - name: Test CPU=1 RANGEIFY=1
-        # TODO: add more passing tests here
-        run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
+        # test_symbolic_arange_sym_step is passing now
+        # test_threefry_doesnt_use_long is because there's a contig after the long now
+        run: |
+          CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
+            -k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \
+            test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \
+            test/test_outerworld_range.py test/test_sample.py test/test_randomness.py test/test_tensor_data.py
       - name: Test CPU=1 RANGEIFY=1 PARTIAL_CONTIG=1
         run: PARTIAL_CONTIG=1 CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
+      - name: Test LLVM=1 RANGEIFY=1 (slow tests)
+        run: LLVM=1 RANGEIFY=1 python3 -m pytest -n auto test/models/test_mnist.py --durations 20
   testdevectorize:
     name: Linux (devectorize)
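Note on the new test step: the -k expression deselects the two tests named in the comments, and everything else in the listed files runs under RANGEIFY. A rough local equivalent of that invocation as a Python sketch (assuming pytest and pytest-xdist are installed; CPU/RANGEIFY are tinygrad's env switches and must be set before tinygrad is imported):

    import os, pytest

    # env switches are read by tinygrad at import time, so set them first
    os.environ.update({"CPU": "1", "RANGEIFY": "1"})
    pytest.main([
      "-n", "auto", "--durations", "20",
      # -k deselects by test-name match, mirroring the CI step above
      "-k", "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long",
      "test/test_tiny.py", "test/test_rangeify.py", "test/test_ops.py",
    ])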


@@ -10,7 +10,12 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, K
 # 0. do some cleanup rewrites, mostly copied from the old stuff
-earliest_rewrites = PatternMatcher([
+double_reshape = PatternMatcher([
+  # RESHAPE on RESHAPE is the second reshape
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
+])
+
+earliest_rewrites = double_reshape+PatternMatcher([
   # UOp with size 0 is zero
   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
   # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
@@ -18,8 +23,6 @@ earliest_rewrites = PatternMatcher([
   # reduce of size 0 is the identity element
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-  # RESHAPE on RESHAPE is the second reshape
-  (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
   # non shape changing RESHAPE is NOOP
   (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
   # RESHAPE after COPY
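The rule that moves into its own double_reshape matcher (so pm_cleanups can reuse it, see the last hunk) says a RESHAPE of a RESHAPE keeps only the outer one: the final shape wins, so the inner reshape is dead. A minimal sketch of the idea using a hypothetical toy Node class, not tinygrad's real UOp/UPat machinery:

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class Node:
      op: str
      src: tuple          # source nodes
      arg: object = None  # for RESHAPE, the target shape

    # mirrors x.replace(src=(x.src[0].src[0],)): skip the inner RESHAPE
    def collapse_double_reshape(x: Node) -> Node:
      if x.op == "RESHAPE" and x.src and x.src[0].op == "RESHAPE":
        return replace(x, src=(x.src[0].src[0],))
      return x

    buf = Node("BUFFER", ())
    r1 = Node("RESHAPE", (buf,), (3, 4))
    r2 = Node("RESHAPE", (r1,), (2, 6))
    assert collapse_double_reshape(r2) == Node("RESHAPE", (buf,), (2, 6))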
@@ -49,6 +52,9 @@ def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
   for s in rb.src:
     if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
 
+def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
+  if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
+
 do_realize = PatternMatcher([
   # always realize SINK parents
   (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
@@ -56,6 +62,8 @@ do_realize = PatternMatcher([
   (UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
   # realize parents of COPY, MSELECT, MSTACK
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
+  # realize input to assign (might be optimized out)
+  (UPat(Ops.ASSIGN, name="a"), realize_assign),
 ])
 
 add_contiguous = PatternMatcher([
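realize_assign is the core fix of the commit: the value operand of an ASSIGN (a.src[1]) is now marked for realization unless it is already in ALWAYS_CONTIGUOUS, so it cannot be folded into the in-place write and dropped. A user-level sketch of the kind of assign this concerns, using the standard Tensor.assign API (illustrative only; it does not show the rewritten UOp graph):

    from tinygrad import Tensor

    a = Tensor.zeros(4).contiguous().realize()
    # the assigned value reads the buffer being assigned; realizing the
    # value side keeps the read of `a` distinct from the in-place write
    a.assign(a + 1)
    a.realize()
    print(a.tolist())  # [1.0, 1.0, 1.0, 1.0]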
@@ -153,18 +161,18 @@ def map_expand(r:UOp, idx:UOp):
 pm_mops = PatternMatcher([
   # this is like the definitions of these
-  (UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"),
+  (UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
-  (UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"),
+  (UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
-  (UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"),
+  (UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
   # expand needs to end ranges
-  (UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand),
+  (UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand),
   # reshape does a lot of symbolic stuff
-  (UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape),
+  (UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape),
   # pad adds min and max
-  (UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad),
+  (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad),
 ])
 
 def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
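This hunk is a behavior-preserving rewrite: .f is pattern-construction sugar, so UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx") matches the same graphs as UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"), just written inside-out so the movement op reads first. A sketch of what such a helper does, with a hypothetical Pat class rather than tinygrad's UPat:

    class Pat:
      def __init__(self, op, src=(), **kw):
        self.op, self.src, self.kw = op, tuple(src), kw
      def f(self, op, **kw):
        # wrap self as the sole listed source of a new enclosing pattern
        return Pat(op, src=(self,), **kw)

    inner = Pat("SHRINK", name="r")
    wrapped = inner.f("INDEX", allow_any_len=True, name="idx")
    assert wrapped.op == "INDEX" and wrapped.src[0] is inner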
@@ -189,8 +197,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
   ranges = []
   for s in x.shape:
     ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
-  ret = x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
-  return ret.forced_reshape(x.shape)
+  return x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape)
 
 def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
   rngs = list(idx.src[1:])
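In map_contiguous, each size-1 dim indexes at a constant 0 and is excluded from the BUFFERIZE axes; only real ranges survive. The change above just inlines the temporary ret. A toy sketch of the range/axis selection (stand-in strings, not real RANGE UOps):

    shape = (4, 1, 3)
    # size-1 dims get constant index 0 instead of a fresh range
    ranges = [f"ridx{i}" if s != 1 else 0 for i, s in enumerate(shape)]
    # only non-constant ranges become bufferize axes
    axes = [r for r in ranges if not isinstance(r, int)]
    assert ranges == ["ridx0", 0, "ridx2"] and axes == ["ridx0", "ridx2"]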
@@ -304,7 +311,7 @@ def remove_bufferize(b2:UOp, idx2:UOp):
   assert all(x.op is Ops.RANGE for x in b2.src[1:])
   return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:])))
 
-pm_cleanups = pm_mops+PatternMatcher([
+pm_cleanups = double_reshape+pm_mops+PatternMatcher([
   #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
   # remove noop buffers. if we look at the next index we can remove even more of these
   # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more