Add tests for casting (#724)

* Add tests for casting

* Skip half_matmul_upcast when TORCH=1

* Fix promotion on torch

* Fix spacing
This commit is contained in:
Jacky Lee
2023-03-23 08:02:52 -07:00
committed by GitHub
parent 68e45fca18
commit e009b6f341
2 changed files with 34 additions and 2 deletions

View File

@@ -24,6 +24,22 @@ class TestDtype(unittest.TestCase):
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
def test_half_mul(self):
a = Tensor([1,2,3,4], dtype=dtypes.float16)
b = Tensor([1,2,3,4], dtype=dtypes.float16)
c = a*b
print(c.numpy())
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [1,4,9,16])
def test_half_matmul(self):
a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
b = Tensor.eye(2, dtype=dtypes.float16)
c = a@b
print(c.numpy())
assert c.dtype == dtypes.float16
np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
def test_upcast_float(self):
# NOTE: there's no downcasting support
a = Tensor([1,2,3,4], dtype=dtypes.float16).float()
@@ -41,5 +57,21 @@ class TestDtype(unittest.TestCase):
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [2,4,6,8])
def test_half_mul_upcast(self):
a = Tensor([1,2,3,4], dtype=dtypes.float16)
b = Tensor([1,2,3,4], dtype=dtypes.float32)
c = a*b
print(c.numpy())
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [1,4,9,16])
def test_half_matmul_upcast(self):
a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
b = Tensor.eye(2, dtype=dtypes.float32)
c = a@b
print(c.numpy())
assert c.dtype == dtypes.float32
np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -11,7 +11,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg)
}}