diff --git a/test/test_dtype.py b/test/test_dtype.py
index dc4766b71a..f46f90de8a 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -24,6 +24,22 @@ class TestDtype(unittest.TestCase):
     assert c.dtype == dtypes.float16
     np.testing.assert_allclose(c.numpy(), [2,4,6,8])
 
+  def test_half_mul(self):
+    a = Tensor([1,2,3,4], dtype=dtypes.float16)
+    b = Tensor([1,2,3,4], dtype=dtypes.float16)
+    c = a*b
+    print(c.numpy())
+    assert c.dtype == dtypes.float16
+    np.testing.assert_allclose(c.numpy(), [1,4,9,16])
+
+  def test_half_matmul(self):
+    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
+    b = Tensor.eye(2, dtype=dtypes.float16)
+    c = a@b
+    print(c.numpy())
+    assert c.dtype == dtypes.float16
+    np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
+
   def test_upcast_float(self):
     # NOTE: there's no downcasting support
     a = Tensor([1,2,3,4], dtype=dtypes.float16).float()
@@ -41,5 +57,21 @@ class TestDtype(unittest.TestCase):
     assert c.dtype == dtypes.float32
     np.testing.assert_allclose(c.numpy(), [2,4,6,8])
 
+  def test_half_mul_upcast(self):
+    a = Tensor([1,2,3,4], dtype=dtypes.float16)
+    b = Tensor([1,2,3,4], dtype=dtypes.float32)
+    c = a*b
+    print(c.numpy())
+    assert c.dtype == dtypes.float32
+    np.testing.assert_allclose(c.numpy(), [1,4,9,16])
+
+  def test_half_matmul_upcast(self):
+    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
+    b = Tensor.eye(2, dtype=dtypes.float32)
+    c = a@b
+    print(c.numpy())
+    assert c.dtype == dtypes.float32
+    np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
+
 if __name__ == '__main__':
-  unittest.main()
\ No newline at end of file
+  unittest.main()
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index 1430ebd5d9..55223e6c08 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -11,7 +11,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
-  FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
+  FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
   MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg)
 }}
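
Outside the patch, a minimal sketch of the PyTorch behavior the ops_torch.py change relies on: torch.promote_types returns the wider of two dtypes, so a mixed float16/float32 MULACC result keeps float32 instead of being downcast to a.dtype as before. The tensor values and einsum string below are illustrative only, not taken from the patch.

import torch

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float16)  # half-precision input
b = torch.eye(2, dtype=torch.float32)                     # single-precision input

# accumulate in float32, as the MULACC lambda does via a.float()/b.float()
acc = torch.einsum("ij,jk->ik", a.float(), b.float())

old = acc.type(a.dtype)                                    # previous cast: result downcast to float16
new = acc.type(torch.promote_types(a.dtype, b.dtype))      # patched cast: result stays float32

print(old.dtype, new.dtype)  # torch.float16 torch.float32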