diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 692b3a746c..98446abb3c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -548,7 +548,7 @@ jobs: run: | RANGEIFY=1 python docs/abstractions2.py - name: Test const folding - run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding" + run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded" # RANGEIFY=2 isn't supported #- name: Test CPU=1 RANGEIFY=2 # run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20 diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index b3d295287f..3d544f1f89 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -115,9 +115,12 @@ def reduce_unparented(red:UOp): for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret -pm_reduce_simplify = PatternMatcher([ +pm_reduce_unparented = PatternMatcher([ # remove any ranges from a REDUCE that aren't referenced in the reduce source (UPat(Ops.REDUCE, name="red"), reduce_unparented), +]) + +pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([ # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse), ]) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 06a930e64c..ab1c364cba 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -7,7 +7,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup from tinygrad.schedule.kernelize import Kernel from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType -from tinygrad.codegen.simplify import pm_flatten_range +from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented from tinygrad.codegen.opt import Opt # ***************** @@ -699,7 +699,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # rangeify tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify") # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right - tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding + tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")