From e4846771b2b53fe6f00763345224810a0f42e614 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 16 Oct 2023 17:50:11 -0700 Subject: [PATCH] Revert "limit metal buffers and revert the 207 fix (try 2) (#2088)" This reverts commit 5e24dc5a95d0d80dd7df74086abf474edc637060. --- test/test_schedule.py | 12 ------------ tinygrad/lazy.py | 12 ++++-------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 0d97d64f90..d5dc2b217c 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -326,17 +326,5 @@ class TestSchedule(unittest.TestCase): out = x.to('cpu') check_schedule(out, 0, filter_loadops=False) - @unittest.skipUnless(Device.DEFAULT == "METAL", "only for metal") - def test_metal_limit_buffers(self): - t = sum([Tensor([1,2,3,4]) for _ in range(40)]) - for si in t.lazydata.schedule(): - assert len(si.inputs) <= 30 - - @unittest.skipUnless(Device.DEFAULT == "METAL", "only for metal") - def test_metal_dont_limit_same_buffers(self): - bt = Tensor(list(range(1, 100))) - out = sum([bt[i:i+2] for i in range(1,40)]) - check_schedule(out, 1) - if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index c97553654b..c073f13ae9 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -202,9 +202,9 @@ class LazyBuffer: # *** elementwise ops *** - def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer: + def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer: # srcs includes self - srcs:Tuple[LazyBuffer, ...] = (self,)+_srcs + srcs = (self,)+srcs # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs) @@ -225,13 +225,9 @@ class LazyBuffer: if MERGE_ELEMENTWISE_OPS: # remove the buffers from any (childless) BinaryOps that feed into this - merged_srcs:Tuple[Union[LazyOp, LazyBuffer], ...] = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore - # NOTE: this is incompete, you can still fuse with reduce ops and exceed the limit - merged_srcs = merged_srcs if self.device != "METAL" or len(dedup([y.base for x in merged_srcs for y in x.buffers])) < 30 else srcs - else: - merged_srcs = srcs + srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore - return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, merged_srcs, arg), out_dtype) + return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype) # *** reduce ops ***