diff --git a/test/test_schedule.py b/test/test_schedule.py
index 297eeed6e6..db0f51d0e2 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
 from tinygrad.codegen.kernel import verify_ast
-from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
+from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -66,6 +66,9 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
 
+@track_rewrites(named=True)
+def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
+
 class TestSchedule(unittest.TestCase):
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
@@ -2228,6 +2231,13 @@ class TestCopyFolding(unittest.TestCase):
     run_schedule(check_schedule(b, 0, filter_sink=False))
     self.assertListEqual(b.tolist(), [0, 0, 0])
 
+  def test_alu_after_copy(self):
+    a = Tensor.ones((4,)).to("CLANG").lazydata
+    b = Tensor.empty(4, device="CLANG").lazydata
+    add = a+b
+    add = schedule_graph_rewrite(add)
+    assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
+
 class TestTensorUOpSpec(unittest.TestCase):
   def test_const_must_be_unmasked(self):
     a = Tensor.ones((4, 4)).pad((2, 2))
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 94404796ca..db0ecccfe8 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -380,8 +380,8 @@ ops_folding = symbolic_simple+PatternMatcher([
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
   # reduce of const is collapsed (TODO: make this a generic rule for stride0)
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
-  # CONST doesn't need COPY
-  (UPat(Ops.COPY, src=(UPat(), UPat.cvar("x"),)), lambda x: x),
+  # COPY(CONST) creates a new CONST on the destination device
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
   # no COPY to same device, except clone (arg is True)
   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
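
For reference, a minimal sketch of the behavior this patch fixes, mirroring test_alu_after_copy without the @track_rewrites wrapper. It assumes a tinygrad checkout at this revision and that "CLANG" differs from the default device; all names come from the diff above.

    from tinygrad import Tensor
    from tinygrad.ops import graph_rewrite
    from tinygrad.engine.schedule import ScheduleContext, remove_movement_ops, ops_folding

    a = Tensor.ones((4,)).to("CLANG").lazydata   # COPY of a CONST from the default device to CLANG
    b = Tensor.empty(4, device="CLANG").lazydata
    add = graph_rewrite(a+b, remove_movement_ops+ops_folding, ScheduleContext())
    # the old rule folded COPY(CONST) to the source CONST, so add.src mixed devices;
    # with root.const_like(x.const_arg) the CONST is rebuilt on the COPY's device
    print([x.device for x in add.src])  # expected: ["CLANG", "CLANG"] after this patch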