mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-07 21:26:21 -05:00
Add tests for casting (#724)
* Add tests for casting
* Skip half_matmul_upcast when TORCH=1
* Fix promotion on torch
* Fix spacing
This commit is contained in:
@@ -24,6 +24,22 @@ class TestDtype(unittest.TestCase):
|
||||
assert c.dtype == dtypes.float16
|
||||
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
|
||||
|
||||
def test_half_mul(self):
    """Elementwise multiply of two float16 tensors keeps the float16 dtype."""
    a = Tensor([1,2,3,4], dtype=dtypes.float16)
    b = Tensor([1,2,3,4], dtype=dtypes.float16)
    c = a*b
    # Realize once and reuse: each .numpy() call triggers evaluation.
    result = c.numpy()
    print(result)
    assert c.dtype == dtypes.float16
    np.testing.assert_allclose(result, [1,4,9,16])
|
||||
|
||||
def test_half_matmul(self):
    """Matmul of two float16 tensors keeps the float16 dtype.

    Multiplying by the identity matrix makes the expected output equal
    to the input, so only the dtype propagation is really under test.
    """
    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
    b = Tensor.eye(2, dtype=dtypes.float16)
    c = a@b
    # Realize once and reuse: each .numpy() call triggers evaluation.
    result = c.numpy()
    print(result)
    assert c.dtype == dtypes.float16
    np.testing.assert_allclose(result, [[1,2],[3,4]])
|
||||
|
||||
def test_upcast_float(self):
|
||||
# NOTE: there's no downcasting support
|
||||
a = Tensor([1,2,3,4], dtype=dtypes.float16).float()
|
||||
@@ -41,5 +57,21 @@ class TestDtype(unittest.TestCase):
|
||||
assert c.dtype == dtypes.float32
|
||||
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
|
||||
|
||||
def test_half_mul_upcast(self):
    """Multiplying float16 by float32 promotes the result to float32."""
    a = Tensor([1,2,3,4], dtype=dtypes.float16)
    b = Tensor([1,2,3,4], dtype=dtypes.float32)
    c = a*b
    # Realize once and reuse: each .numpy() call triggers evaluation.
    result = c.numpy()
    print(result)
    assert c.dtype == dtypes.float32
    np.testing.assert_allclose(result, [1,4,9,16])
|
||||
|
||||
def test_half_matmul_upcast(self):
    """Matmul of float16 with float32 promotes the result to float32.

    Uses a float32 identity matrix so the expected values equal the input
    and only dtype promotion is exercised.
    """
    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
    b = Tensor.eye(2, dtype=dtypes.float32)
    c = a@b
    # Realize once and reuse: each .numpy() call triggers evaluation.
    result = c.numpy()
    print(result)
    assert c.dtype == dtypes.float32
    np.testing.assert_allclose(result, [[1,2],[3,4]])
|
||||
|
||||
if __name__ == '__main__':
    # Single entry point; a second unittest.main() call would be unreachable
    # anyway since unittest.main() exits the interpreter by default.
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user