rangeify: TestReduceOpsConstFolding (#12397)

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
b1tg
2025-10-01 17:58:19 +08:00
committed by GitHub
parent 60e52fbe36
commit ac3d457d5e
3 changed files with 7 additions and 4 deletions

View File

@@ -548,7 +548,7 @@ jobs:
run: |
RANGEIFY=1 python docs/abstractions2.py
- name: Test const folding
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding"
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded"
# RANGEIFY=2 isn't supported
#- name: Test CPU=1 RANGEIFY=2
# run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20

View File

@@ -115,9 +115,12 @@ def reduce_unparented(red:UOp):
for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
pm_reduce_simplify = PatternMatcher([
pm_reduce_unparented = PatternMatcher([
# remove any ranges from a REDUCE that aren't referenced in the reduce source
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
])
pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
])

View File

@@ -7,7 +7,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
# *****************
@@ -699,7 +699,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# rangeify
tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify")
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")