remove support for checking tensor uops in FUSE_ARANGE [pr] (#8829)

commit a78f0f85d3
parent 2a33750e4c
Author: qazal
Date:   2025-01-31 00:48:28 -05:00
Committed by: GitHub
2 changed files with 8 additions and 10 deletions
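For context, FUSE_ARANGE is a scheduler context variable that lets the reduce implementing Tensor.arange fuse into the kernels that consume it instead of realizing into its own buffer. A minimal sketch of exercising it, assuming a tinygrad checkout around this commit (shapes and values mirror the tests below; nothing here is prescribed by the diff itself):

from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(FUSE_ARANGE=1):
  Tensor.manual_seed(0)
  x = Tensor.randn(5, 2).realize()
  a = Tensor.arange(10)
  # the arange reduce fuses into the consuming sum kernel
  out = (x + a[2]).sum()
  print(out.numpy())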

test/test_schedule.py

@@ -1650,6 +1650,7 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(out, 2)
     np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)

+  @unittest.skip("TODO: FUSE_ARANGE overrules Tensor.arange().contiguous()")
   def test_arange_index_contiguous(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(5, 2).realize()
@@ -1666,6 +1667,7 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(out, 2)
     np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)

+  @unittest.skip("TODO: FUSE_ARANGE overrules Tensor.arange().contiguous()")
   def test_arange_index_contiguous_child(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(5, 2).realize()
@@ -1732,6 +1734,7 @@ class TestIndexing(unittest.TestCase):
     run_schedule(check_schedule(ref, 3))
     np.testing.assert_equal(fused.numpy(), ref.numpy())

+  @unittest.skip("TODO: FUSE_ARANGE overrules this contiguous")
   def test_fuse_assign_contiguous(self):
     x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
     a = Tensor.arange(8).reshape(4, 2)
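The skip reasons above describe the regression this commit accepts: with the tensor-uop check gone, an explicit .contiguous() on the arange is overruled whenever FUSE_ARANGE is set. A hedged reconstruction of the scenario (the actual test bodies are truncated above; this only mirrors their visible setup):

from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(FUSE_ARANGE=1):
  Tensor.manual_seed(0)
  x = Tensor.randn(5, 2).realize()
  a = Tensor.arange(10).contiguous()  # previously realized as its own kernel
  out = (x + a[2]).sum()              # after this commit the arange fuses anyway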

tinygrad/engine/schedule.py

@@ -3,7 +3,7 @@ from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
 from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
-from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
+from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
 from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
 from tinygrad.dtype import DType, ImageDType, dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -267,7 +267,6 @@ def group_realizes(ctx:ScheduleContext) -> None:
   """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
   # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: dict[UOp, UOp] = {}
-  reduce_of_const: list[UOp] = []
   double_reduces: list[UOp] = []
   for r, r_uop in ctx.allbufs.items():
     if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
@@ -308,18 +307,14 @@ def group_realizes(ctx:ScheduleContext) -> None:
       group = {tr: None}
       ctx.realizes[tr] = tr
     reduce_for_op.update((tr, r) for tr in group)
-    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST: reduce_of_const.append(r)
+    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
+      # maybe fuse arange with its children
+      if len(flatten(ctx.children[tr] for tr in group)) != 0:
+        for tr in group: del ctx.realizes[tr]
   # fuse double reduces with no other child
   for reduceop in double_reduces:
     top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
-  # maybe fuse arange with its children
-  for rbuf in reduce_of_const:
-    group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
-    if any(tensor_uop.op is Ops.CONTIGUOUS for tr in group for tensor_uop in ctx.tensor_uops[tr]): continue
-    kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
-    if len(kernel_children) == 0: continue
-    for tr in group: del ctx.realizes[tr]

 # **** Schedule creation and BFS toposort
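Net effect of the hunk above: the deferred second pass over reduce_of_const, with its check of ctx.tensor_uops for an explicit Ops.CONTIGUOUS and its filtering of COPY/BUFFER_VIEW children, collapses into one inline question: does the arange group have any children at all? The newly imported flatten from tinygrad.helpers just concatenates the per-buffer child lists (flatten([[1, 2], [3]]) == [1, 2, 3]). A distilled sketch of the new decision, with hypothetical standalone names (the real code operates on a ScheduleContext):

from tinygrad.helpers import flatten

def should_fuse_arange(children: dict, group: dict) -> bool:
  # fuse only if some buffer in the group feeds a downstream kernel;
  # otherwise keep the forced realize so the arange still materializes
  return len(flatten(children[tr] for tr in group)) != 0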