mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix some assigns on rangeify (#11774)
* fix some assigns
* llvm test
* more tests
* upd test
.github/workflows/test.yml (13 changed lines, vendored)
@@ -601,13 +601,22 @@ jobs:
     - name: Setup Environment
       uses: ./.github/actions/setup-tinygrad
       with:
-        key: rangeify-minimal
+        key: rangeify-minimal-llvm
         deps: testing_minimal
+        llvm: "true"
     - name: Test CPU=1 RANGEIFY=1
       # TODO: add more passing tests here
-      run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
+      # test_symbolic_arange_sym_step is passing now
+      # test_threefry_doesnt_use_long is because there's a contig after the long now
+      run: |
+        CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
+          -k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \
+          test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \
+          test/test_outerworld_range.py test/test_sample.py test/test_randomness.py test/test_tensor_data.py
     - name: Test CPU=1 RANGEIFY=1 PARTIAL_CONTIG=1
       run: PARTIAL_CONTIG=1 CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
+    - name: Test LLVM=1 RANGEIFY=1 (slow tests)
+      run: LLVM=1 RANGEIFY=1 python3 -m pytest -n auto test/models/test_mnist.py --durations 20
 
   testdevectorize:
     name: Linux (devectorize)
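The new run: block widens the test file list and deselects two tests by name with a pytest -k expression instead of dropping files. A minimal sketch of how -k name matching behaves, using a hypothetical test file that is not from this repo:

# test_k_expr.py -- toy illustration of pytest -k name matching (hypothetical file).
# `python3 -m pytest -k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long"`
# deselects any test whose name matches either substring and runs everything else.

def test_tiny_add():                      # runs: matches neither substring
    assert 1 + 1 == 2

def test_symbolic_arange_sym_step():      # deselected by the -k expression
    assert list(range(0, 6, 2)) == [0, 2, 4]

def test_threefry_doesnt_use_long():      # deselected by the -k expression
    assert 2**31 < 2**32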
@@ -10,7 +10,12 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, K
 
 # 0. do some cleanup rewrites, mostly copied from the old stuff
 
-earliest_rewrites = PatternMatcher([
+double_reshape = PatternMatcher([
+  # RESHAPE on RESHAPE is the second reshape
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
+])
+
+earliest_rewrites = double_reshape+PatternMatcher([
   # UOp with size 0 is zero
   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
   # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
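The reshape-of-reshape rule is pulled out into its own double_reshape matcher so it can be reused (pm_cleanups picks it up below). The rule is sound because composing reshapes depends only on the final shape; a quick numpy check of that fact, not tinygrad code:

# RESHAPE(RESHAPE(x)) can keep just the outer reshape, applied to the inner
# reshape's source, because only the final shape matters.
import numpy as np

x = np.arange(24)
assert np.array_equal(x.reshape(2, 12).reshape(4, 6), x.reshape(4, 6))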
@@ -18,8 +23,6 @@ earliest_rewrites = PatternMatcher([
   # reduce of size 0 is the identity element
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-  # RESHAPE on RESHAPE is the second reshape
-  (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
   # non shape changing RESHAPE is NOOP
   (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
   # RESHAPE after COPY
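The surviving rule above constant-folds a reduce over a size-0 input into the identity element of the reduction op. A plain-Python sanity check of the underlying math (not tinygrad's identity_element):

# Reducing an empty input yields the reduction's identity element:
# 0 for add, 1 for mul, -inf for max.
import math

assert sum([]) == 0
assert math.prod([]) == 1
assert max([], default=-math.inf) == -math.inf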
@@ -49,6 +52,9 @@ def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
   for s in rb.src:
     if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
 
+def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
+  if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
+
 do_realize = PatternMatcher([
   # always realize SINK parents
   (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
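realize_assign is the core of the assign fix: it forces the value side of an ASSIGN (a.src[1]) to be realized unless that op never needs its own buffer. A toy sketch of the convention, assuming the dict[UOp, None] ctx is used as an insertion-ordered set (my reading, not tinygrad code):

# Toy model: nodes are tuples ("OP", *srcs); ctx's keys act as an ordered set
# of nodes that must be realized into buffers.
ALWAYS_CONTIGUOUS = {"BUFFER", "CONST"}   # hypothetical stand-in for the real set

def realize_assign(ctx: dict, assign: tuple) -> None:
    _op, _target, value = assign          # ASSIGN(target, value)
    if value[0] not in ALWAYS_CONTIGUOUS:
        ctx[value] = None                 # mark the assigned value for realization

ctx: dict = {}
realize_assign(ctx, ("ASSIGN", ("BUFFER", "t"), ("ADD", "a", "b")))
assert list(ctx) == [("ADD", "a", "b")]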
@@ -56,6 +62,8 @@ do_realize = PatternMatcher([
   (UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
   # realize parents of COPY, MSELECT, MSTACK
   (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
+  # realize input to assign (might be optimized out)
+  (UPat(Ops.ASSIGN, name="a"), realize_assign),
 ])
 
 add_contiguous = PatternMatcher([
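The added entry is what wires realize_assign in: do_realize pairs a UPat with a callback that mutates the shared ctx. A stripped-down sketch of that dispatch shape, illustrative only and not the real PatternMatcher:

# Each rule is (predicate, callback); matching nodes get their callback run
# against the shared ctx, mirroring (UPat(Ops.ASSIGN, name="a"), realize_assign).
from typing import Callable

Rule = tuple[Callable[[tuple], bool], Callable[[dict, tuple], None]]

def run_rules(rules: list[Rule], ctx: dict, node: tuple) -> None:
    for matches, callback in rules:
        if matches(node): callback(ctx, node)

toy_do_realize: list[Rule] = [
    (lambda n: n[0] == "ASSIGN", lambda ctx, n: ctx.update({n[2]: None})),
]

ctx: dict = {}
run_rules(toy_do_realize, ctx, ("ASSIGN", "target", "value"))
assert "value" in ctx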
@@ -153,18 +161,18 @@ def map_expand(r:UOp, idx:UOp):
 
 pm_mops = PatternMatcher([
   # this is like the definitions of these
-  (UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"),
+  (UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
-  (UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"),
+  (UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
-  (UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"),
+  (UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
   # expand needs to end ranges
-  (UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand),
+  (UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand),
   # reshape does a lot of symbolic stuff
-  (UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape),
+  (UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape),
   # pad adds min and max
-  (UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad),
+  (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad),
 ])
 
 def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
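This hunk is a notation change: UPat(Ops.X, name="r").f(Ops.INDEX, ...) builds the same INDEX-of-movement-op pattern as the nested UPat form, and every callback is untouched. The rewrites themselves implement movement ops as index remappings; the arithmetic they encode can be checked with numpy (illustrative, not tinygrad code):

# A movement op under an INDEX becomes an INDEX on its source with remapped
# coordinates: SHRINK offsets, PERMUTE reorders, FLIP mirrors.
import numpy as np

x = np.arange(12).reshape(3, 4)
i, j = 1, 2

assert x[1:3, 1:4][i, j] == x[i + 1, j + 1]        # shrink: add the start offset
assert x.T[i, j] == x[j, i]                        # permute: reorder the indices
assert x[:, ::-1][i, j] == x[i, (4 - 1) - j]       # flip: mirror along the axis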
@@ -189,8 +197,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
   ranges = []
   for s in x.shape:
     ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
-  ret = x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
-  return ret.forced_reshape(x.shape)
+  return x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape)
 
 def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
   rngs = list(idx.src[1:])
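map_contiguous's behavior is unchanged; the two-line tail is just collapsed into one chained return. The loop above it allocates one fresh range per axis and indexes size-1 axes with a constant 0 so no range is wasted on them. A toy sketch of that allocation, under assumed semantics for ctx.new_range:

# One loop counter per axis; size-1 axes index with a constant 0 and get no range.
from itertools import count

def make_ranges(shape: tuple) -> list:
    fresh = count()                        # stand-in for ctx.new_range
    return [("RANGE", next(fresh), s) if s != 1 else ("CONST", 0) for s in shape]

assert make_ranges((4, 1, 8)) == [("RANGE", 0, 4), ("CONST", 0), ("RANGE", 1, 8)]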
@@ -304,7 +311,7 @@ def remove_bufferize(b2:UOp, idx2:UOp):
   assert all(x.op is Ops.RANGE for x in b2.src[1:])
   return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:])))
 
-pm_cleanups = pm_mops+PatternMatcher([
+pm_cleanups = double_reshape+pm_mops+PatternMatcher([
   #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
   # remove noop buffers. if we look at the next index we can remove even more of these
   # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
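pm_cleanups now also runs double_reshape, so reshape chains produced during cleanup keep collapsing. remove_bufferize above it drops an intermediate buffer by substituting the buffer's RANGEs with the ranges of the INDEX that reads it; a toy version of that substitution idea, using my own tuple encoding rather than tinygrad's UOp.substitute:

# Rename the bufferized body's ranges to the outer INDEX's ranges, so the
# expression can be inlined and the intermediate buffer skipped.
def substitute(expr, mapping):
    if expr in mapping: return mapping[expr]
    if isinstance(expr, tuple): return tuple(substitute(e, mapping) for e in expr)
    return expr

body = ("ADD", "r0", "r1")                 # BUFFERIZE body over ranges r0, r1
assert substitute(body, {"r0": "i", "r1": "j"}) == ("ADD", "i", "j")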