Add tests for casting (#724)

* Add tests for casting

* Skip half_matmul_upcast when TORCH=1

* Fix promotion on torch

* Fix spacing
This commit is contained in:
Jacky Lee
2023-03-23 08:02:52 -07:00
committed by GitHub
parent 68e45fca18
commit e009b6f341
2 changed files with 34 additions and 2 deletions

View File

@@ -24,6 +24,22 @@ class TestDtype(unittest.TestCase):
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
def test_half_mul(self):
a = Tensor([1,2,3,4], dtype=dtypes.float16)
b = Tensor([1,2,3,4], dtype=dtypes.float16)
c = a*b
print(c.numpy())
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [1,4,9,16])
def test_half_matmul(self):
a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
b = Tensor.eye(2, dtype=dtypes.float16)
c = a@b
print(c.numpy())
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
def test_upcast_float(self):
# NOTE: there's no downcasting support
a = Tensor([1,2,3,4], dtype=dtypes.float16).float()
@@ -41,5 +57,21 @@ class TestDtype(unittest.TestCase):
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
def test_half_mul_upcast(self):
a = Tensor([1,2,3,4], dtype=dtypes.float16)
b = Tensor([1,2,3,4], dtype=dtypes.float32)
c = a*b
print(c.numpy())
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [1,4,9,16])
def test_half_matmul_upcast(self):
a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
b = Tensor.eye(2, dtype=dtypes.float32)
c = a@b
print(c.numpy())
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
if __name__ == '__main__':
unittest.main()
unittest.main()