mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
late fusion tests, early merge view GroupOp.Buffer [pr] (#7577)
* test_late_fusion_double_transpose * early merge view buffer ops
This commit is contained in:
@@ -1360,6 +1360,12 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(out, 4))
|
||||
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_late_fusion_double_transpose(self):
|
||||
with Context(DEBUG=0): a = Tensor.randn(32, 16, 1).realize()
|
||||
compare = (a.expand(32, 16, 16).sum((2,), keepdim=True).T+2).T.contiguous()
|
||||
run_schedule(check_schedule(compare, 1))
|
||||
np.testing.assert_allclose(compare.numpy(), (np.broadcast_to(a.numpy(), (32, 16, 16)).sum(axis=2, keepdims=True).T+2).T, atol=1e-4, rtol=1e-4)
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
|
||||
Reference in New Issue
Block a user