base change scheduling spec (#4613)

* spec and kernel cnt

* dont use half

* skip half
This commit is contained in:
qazal
2024-05-16 18:30:49 +08:00
committed by GitHub
parent 65f7e3b3ab
commit 0b464df605

View File

@@ -807,5 +807,40 @@ class TestSchedule(unittest.TestCase):
with self.assertRaises(AssertionError):
np.testing.assert_equal(out.numpy(), [2, 0])
def test_base_change_shrink_pad(self):
a = Tensor.ones(3, 3).contiguous().realize()
b = a.exp2()
c = b[:-1, :-1]
d = c.pad(((0, 1), (0, 1))) * 2
run_schedule(check_schedule(d, 1))
with self.assertRaises(AssertionError): # TODO unsafe pads
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
def test_base_change_expand_pad(self):
a = Tensor.ones(3, 3).contiguous().realize()
b = a.exp2()
c = b[:, None, :]
d = c.pad(((0, 0), (1, 1), (0, 0))) * 2
run_schedule(check_schedule(d, 2))
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
# TODO like openpilot with imagef
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_base_change_expand_expand(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.cast(dtypes.half).expand(2, 4, 4)
c = b.cast(dtypes.int).expand(2, 2, 4, 4)
run_schedule(check_schedule(c, 2))
np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32))
def test_base_change_pad_expand(self):
a = Tensor.full((4, 4), 1.).contiguous().realize()
b = Tensor.full((4, 4), 2.).contiguous().realize()
c = (a + b).pad(((1, 1), (1, 1)))
d = c.cast(dtypes.int).expand((2, 6, 6)) * 4
run_schedule(check_schedule(d, 2))
c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0)
np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.half), (2, *c_np.shape)) * 4)
if __name__ == '__main__':
unittest.main(verbosity=2)