mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove support for checking tensor uops in FUSE_ARANGE [pr] (#8829)
This commit is contained in:
@@ -1650,6 +1650,7 @@ class TestIndexing(unittest.TestCase):
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
@unittest.skip("TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous()")
|
||||
def test_arange_index_contiguous(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
@@ -1666,6 +1667,7 @@ class TestIndexing(unittest.TestCase):
|
||||
self.check_schedule(out, 2)
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
@unittest.skip("TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous()")
|
||||
def test_arange_index_contiguous_child(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(5, 2).realize()
|
||||
@@ -1732,6 +1734,7 @@ class TestIndexing(unittest.TestCase):
|
||||
run_schedule(check_schedule(ref, 3))
|
||||
np.testing.assert_equal(fused.numpy(), ref.numpy())
|
||||
|
||||
@unittest.skip("TOOD: FUSE_ARANGE overrules this contiguous")
|
||||
def test_fuse_assign_contiguous(self):
|
||||
x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
|
||||
a = Tensor.arange(8).reshape(4, 2)
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
|
||||
from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
|
||||
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap, flatten
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -267,7 +267,6 @@ def group_realizes(ctx:ScheduleContext) -> None:
|
||||
"""search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
|
||||
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: dict[UOp, UOp] = {}
|
||||
reduce_of_const: list[UOp] = []
|
||||
double_reduces: list[UOp] = []
|
||||
for r, r_uop in ctx.allbufs.items():
|
||||
if (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue
|
||||
@@ -308,18 +307,14 @@ def group_realizes(ctx:ScheduleContext) -> None:
|
||||
group = {tr: None}
|
||||
ctx.realizes[tr] = tr
|
||||
reduce_for_op.update((tr, r) for tr in group)
|
||||
if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST: reduce_of_const.append(r)
|
||||
if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.CONST:
|
||||
# maybe fuse arange with its children
|
||||
if len(flatten(ctx.children[tr] for tr in group)) != 0:
|
||||
for tr in group: del ctx.realizes[tr]
|
||||
# fuse double reduces with no other child
|
||||
for reduceop in double_reduces:
|
||||
top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
|
||||
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
|
||||
# maybe fuse arange with its children
|
||||
for rbuf in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
|
||||
if any(tensor_uop.op is Ops.CONTIGUOUS for tr in group for tensor_uop in ctx.tensor_uops[tr]): continue
|
||||
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
|
||||
if len(kernel_children) == 0: continue
|
||||
for tr in group: del ctx.realizes[tr]
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
|
||||
Reference in New Issue
Block a user