mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
push permutes through fused reduces (#10628)
* fix pushing reshapes through reduceops * reduceop_view_right should assert on ndims mismatch * update that, view.reshape asserts it
This commit is contained in:
@@ -89,14 +89,20 @@ class TestSchedule(unittest.TestCase):
|
||||
with self.assertRaises(RecursionError):
|
||||
with Context(FUSE_ARANGE=1, NOOPT=0): self.test_arange_avgpool2d(kcount=1)
|
||||
|
||||
# grouper error
|
||||
@unittest.expectedFailure
|
||||
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
|
||||
# all permutes, reshapes, expands and shrinks push through the reduce
|
||||
def test_arange_sum(self):
|
||||
a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [1, 5, 9])
|
||||
|
||||
def test_permute_arange(self):
|
||||
a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
|
||||
with Context(FUSE_ARANGE=1):
|
||||
run_schedule(check_schedule(a, 1))
|
||||
self.assertListEqual(a.tolist(), [[15]])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
|
||||
def test_error_on_device_mismatch(self):
|
||||
a = Tensor.empty(10)
|
||||
|
||||
Reference in New Issue
Block a user