mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
rangeify: TestReduceOpsConstFolding (#12397)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -548,7 +548,7 @@ jobs:
|
||||
run: |
|
||||
RANGEIFY=1 python docs/abstractions2.py
|
||||
- name: Test const folding
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding"
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded"
|
||||
# RANGEIFY=2 isn't supported
|
||||
#- name: Test CPU=1 RANGEIFY=2
|
||||
# run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
|
||||
|
||||
@@ -115,9 +115,12 @@ def reduce_unparented(red:UOp):
|
||||
for r in reduce_unparented: ret = ret ** r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
pm_reduce_simplify = PatternMatcher([
|
||||
pm_reduce_unparented = PatternMatcher([
|
||||
# remove any ranges from a REDUCE that aren't referenced in the reduce source
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
|
||||
])
|
||||
|
||||
pm_reduce_simplify = pm_reduce_unparented + PatternMatcher([
|
||||
# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
|
||||
(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
|
||||
])
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
|
||||
from tinygrad.schedule.kernelize import Kernel
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# *****************
|
||||
@@ -699,7 +699,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
# rangeify
|
||||
tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify")
|
||||
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
|
||||
tsink = graph_rewrite(tsink, symbolic_simple, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
|
||||
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user