diff --git a/test/test_schedule.py b/test/test_schedule.py index d0584c1631..330b20231d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -11,7 +11,7 @@ from tinygrad import nn, dtypes, Device, Tensor from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp, view_left +from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites, merge_views, view_left from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.spec import type_verify, shape_spec from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp @@ -19,7 +19,6 @@ from tinygrad.engine.grouper import view_right, sym from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis -remove_movement_ops = merge_views def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec) class KernelCountException(Exception): pass @@ -69,7 +68,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs): np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) @track_rewrites(named=True) -def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {}) +def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, merge_views+sym, {}) class TestSchedule(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") @@ -623,7 +622,7 @@ class TestSchedule(unittest.TestCase): def test_pow_const_tensor_to_zero(self): x = Tensor([1,2,3,4]) out = x ** Tensor(0.0) - # NOTE: this is ConstBuffer 0 + ConstBuffer 1 + # NOTE: 
this is UOp.const(0) + UOp.const(1) check_schedule(out, 0) def test_zero_size(self): @@ -643,7 +642,6 @@ class TestSchedule(unittest.TestCase): out = x.sum(1).relu().elu() + y.sum(1).relu().elu() check_schedule(out, 2) - # multireduce spec @unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reducop requires SPLIT_REDUCEOP") def test_preserve_multistage_reduce(self): big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768) @@ -664,7 +662,6 @@ class TestSchedule(unittest.TestCase): out = x.relu().sum(1) + out2[0] check_schedule(out, 2) - # multireduce spec @unittest.skip("these two Tensors are the same") def test_example_matmul(self): x = Tensor.eye(64, requires_grad=True) @@ -712,9 +709,9 @@ class TestSchedule(unittest.TestCase): x = x.sum(1) x = x[:16] out = x + y - check_schedule(out, 2) # TODO: this should be 1 + # NOTE: this could be 1 kernel if we mask the store? + check_schedule(out, 2) - # multireduce spec def test_multireduce_shrink(self): Tensor.manual_seed(0) a = Tensor.randn(32, 32).realize() @@ -737,7 +734,6 @@ class TestSchedule(unittest.TestCase): out = x.contiguous() + y.contiguous() check_schedule(out, 2, filter_sink=False) - # multireduce spec @unittest.expectedFailure def test_reduce_same_size(self): Tensor.manual_seed(0) @@ -750,7 +746,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6) np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6) - # multireduce spec @unittest.expectedFailure def test_reduce_multiple_paths(self): Tensor.manual_seed(0) @@ -762,7 +757,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6) - # multireduce spec def test_multireduce_reduce_multiple_paths(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -779,7 +773,6 @@ class 
TestSchedule(unittest.TestCase): np.testing.assert_allclose(out2.numpy(), np_out2:=np.exp2(np_b.sum()), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out3.numpy(), np_b.sum()+np_out2, atol=1e-4, rtol=1e-4) - # multireduce spec def test_reduce_ext_reduce_child(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -792,7 +785,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4) - # multireduce spec def test_reduce_multiple_paths_midreduce(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -808,7 +800,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out1.numpy(), out1_np:=(a.numpy() - out0_np).max(), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out2.numpy(), r_np + out1_np, atol=1e-4, rtol=1e-4) - # multireduce spec def test_reduce_multiple_paths_midreduce_fused(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -822,7 +813,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6) np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6) - # multireduce spec def test_reduce_multiple_paths_midexpand(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4).realize() @@ -873,7 +863,6 @@ class TestSchedule(unittest.TestCase): out1 = out0[0] + Tensor.empty(1, ) check_schedule([r, out0, out1], 3) - # multireduce spec def test_std_multireduce_fusion(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -881,7 +870,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_argmin_multireduce_fusion(self): Tensor.manual_seed(0) x = Tensor.randn(4, 
32).realize() @@ -889,7 +877,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 3)) np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1)) - # multireduce spec def test_argmax_multireduce_fusion(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -909,7 +896,6 @@ class TestSchedule(unittest.TestCase): compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy())) np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3) - # multireduce spec def test_ugly_reduceop_pairing(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() @@ -921,7 +907,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out.numpy(), \ (c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_reduce_expand_reduce_fusion(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() @@ -930,7 +915,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_reduce_expand_reduce_expand_fusion(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() @@ -940,7 +924,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out.numpy(), \ a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4) - # multireduce spec def test_branching_reduces_and_expands_fusion(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() @@ -951,7 +934,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out1.numpy(), 
(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_fusion_simple_sequential(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -961,7 +943,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_fusion_simple_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -971,7 +952,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_fusion_sequential(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -980,7 +960,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_fusion_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -990,7 +969,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 4)) np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_diffops_sequential(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -999,7 +977,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_fusion_diffops_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -1009,7 +986,6 @@ class TestSchedule(unittest.TestCase): 
run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multireduce_fusion_sequential_and_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() @@ -1023,7 +999,6 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4) - # multireduce spec def test_multimatmul_fusion(self): Tensor.manual_seed(0) a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize() @@ -1228,7 +1203,6 @@ class TestSchedule(unittest.TestCase): schedule = check_schedule([b, c], 3) self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) - # multireduce spec def test_multireduce_simple_chase(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4, 4).realize() @@ -1252,7 +1226,6 @@ class TestSchedule(unittest.TestCase): schedule = check_schedule([d, e], 3) self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) - # multireduce spec def test_multireduce_push_permute_chase(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4, 4).realize() @@ -1275,7 +1248,6 @@ class TestSchedule(unittest.TestCase): schedule = check_schedule(d, 2) self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) - # multireduce spec def test_multireduce_push_shrink_chase(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() @@ -1296,7 +1268,6 @@ class TestSchedule(unittest.TestCase): schedule = check_schedule(b, 2) self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS) - # multireduce spec def test_multireduce_midreduce_nochase(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() @@ -1376,7 +1347,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), 
np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6) - # multireduce spec def test_multireduce_pad_reduce_safe(self): Tensor.manual_seed(0) a = Tensor.randn(3, 4, 5).realize() @@ -1394,7 +1364,6 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6) - # multireduce spec def test_multireduce_pad_reduce_unsafe(self): Tensor.manual_seed(0) a = Tensor.randn(3, 4, 5).abs().realize() @@ -1471,11 +1440,10 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(tiny_ret, p) def test_bitcast_fuses(self): - x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata) - a = x.alu(Ops.EXP2).bitcast(dtypes.int32) + x = Tensor.empty(1, dtype=dtypes.float32) + a = x.exp2().bitcast(dtypes.int32) b = x.bitcast(dtypes.int32) - b = a.alu(Ops.ADD, b) - check_schedule(b, 1) # this should fuse when it makes sense + check_schedule(a+b, 1) # this should fuse when it makes sense @unittest.skip("disabling subbuffer manually isn't supported anymore") def test_bitcast_disable_subbufer(self): @@ -1503,7 +1471,6 @@ class TestSchedule(unittest.TestCase): @unittest.skip("splitting kernels exceeding device buffer count is not yet supported") def _test_buf_cnt(self, cnt:int, allowed:int): - #if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}") alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)]) s = alu.schedule() assert len(s) == allowed @@ -1517,6 +1484,7 @@ class TestSchedule(unittest.TestCase): @unittest.expectedFailure def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3) + @unittest.skipIf(getenv("VIZ"), "TODO: VIZ blocks gc") def test_schedule_mem_used(self): base = GlobalCounters.mem_used 
Tensor.ones(256).contiguous().realize() @@ -1648,7 +1616,7 @@ class TestIndexing(unittest.TestCase): self.check_schedule(xt, 6) np.testing.assert_equal(xt.numpy(), 6) - @unittest.skip("TODO: support pads in graph_rewrite") + @unittest.skip("TODO: break the schedule if dims don't match") def test_advanced_simple_indexing_combined(self): X = Tensor.arange(16).reshape(4, 4) xt = X[1:2, [1, 2]] @@ -2060,7 +2028,7 @@ class TestView(unittest.TestCase): run_schedule(s) self.assertEqual(other_child.tolist(), [2, 3, 4]) -def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic_simple) +def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, merge_views+symbolic_simple) class TestSimplifier(unittest.TestCase): def test_sink_childless_const(self): x = Tensor(0) @@ -2302,13 +2270,13 @@ class TestTensorUOpSpec(unittest.TestCase): unsafe_push_views = PatternMatcher([ (UPat.cvar("root").view(name="view"), lambda root,view: root.replace(src=tuple(x.view(view.st) for x in root.src))), ]) - a.lazydata = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views+unsafe_push_views) + a.lazydata = graph_rewrite(a.lazydata.sink(), merge_views+unsafe_push_views) with self.assertRaisesRegex(RuntimeError, "UOp verification failed"): a.schedule() def test_expanded_const_ok(self): a = Tensor.ones((4, 4)) - t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views) + t = graph_rewrite(a.lazydata.sink(), merge_views) create_schedule_with_vars(t) # NOTE: changing symbolic CONST VIEWs is not allowed @@ -2316,7 +2284,7 @@ class TestTensorUOpSpec(unittest.TestCase): def test_symbolic_shape_ok(self): a = Tensor.ones(4) vi = UOp.variable("i", 1, 10).bind(4) - a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views) + a.lazydata = graph_rewrite(a.reshape(vi).sum().lazydata, merge_views) a.schedule() class TestBufferUOp(unittest.TestCase): @@ -2348,7 +2316,7 @@
class TestBufferUOp(unittest.TestCase): def test_buffer_view_not_allowed(self): permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1) - merged = graph_rewrite(permuted_view.lazydata, remove_movement_ops) + merged = graph_rewrite(permuted_view.lazydata, merge_views) with self.assertRaisesRegex(AssertionError, "VIEW only works here if it's contiguous"): merged.buffer # cannot access Buffer of a non contiguous VIEW diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 206ee88750..e396f99e6f 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -48,7 +48,7 @@ def split_reduceop(reduce:UOp, x:UOp): sym = symbolic_simple+PatternMatcher([ # UOp with size 0 is zero (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ - and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), + and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), # DETACH and CONTIGUOUS_BACKWARD are NOOPs here (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), # reduce of size 0 is the identity element @@ -64,7 +64,7 @@ sym = symbolic_simple+PatternMatcher([ (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), # remove cast to image when it's already a contiguous image - (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)), + (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)), lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), # make things that can't be images not images (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType) @@ -79,7 +79,7 @@ sym = 
symbolic_simple+PatternMatcher([ lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None), # remove CONST/BIND/VIEW from SINK (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src) - if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None), + if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None), ]) # support for using a contiguous permuted view instead of the parent view if one exists