Add tests for casting (#724)
* Add tests for casting
* Skip half_matmul_upcast when TORCH=1
* Fix promotion on torch
* Fix spacing
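Note that the TORCH=1 skip named in the message does not appear in the hunks below. Purely as an illustration of how such a skip is usually written, here is a hedged sketch; the decorator placement and reason string are hypothetical, not the commit's actual code, and `getenv` is assumed to be tinygrad's env-var helper:

```python
import unittest
from tinygrad.helpers import getenv  # ASSUMPTION: helper location as in tinygrad of this era

class TestDtype(unittest.TestCase):
  # HYPOTHETICAL sketch: skip this case when running against the torch backend (TORCH=1)
  @unittest.skipIf(getenv("TORCH"), "skipped on the torch backend (see commit message)")
  def test_half_matmul_upcast(self):
    ...
```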
```diff
@@ -24,6 +24,22 @@ class TestDtype(unittest.TestCase):
     assert c.dtype == dtypes.float16
     np.testing.assert_allclose(c.numpy(), [2,4,6,8])
 
+  def test_half_mul(self):
+    a = Tensor([1,2,3,4], dtype=dtypes.float16)
+    b = Tensor([1,2,3,4], dtype=dtypes.float16)
+    c = a*b
+    print(c.numpy())
+    assert c.dtype == dtypes.float16
+    np.testing.assert_allclose(c.numpy(), [1,4,9,16])
+
+  def test_half_matmul(self):
+    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
+    b = Tensor.eye(2, dtype=dtypes.float16)
+    c = a@b
+    print(c.numpy())
+    assert c.dtype == dtypes.float16
+    np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
+
   def test_upcast_float(self):
     # NOTE: there's no downcasting support
     a = Tensor([1,2,3,4], dtype=dtypes.float16).float()
```
```diff
@@ -41,5 +57,21 @@ class TestDtype(unittest.TestCase):
     assert c.dtype == dtypes.float32
     np.testing.assert_allclose(c.numpy(), [2,4,6,8])
 
+  def test_half_mul_upcast(self):
+    a = Tensor([1,2,3,4], dtype=dtypes.float16)
+    b = Tensor([1,2,3,4], dtype=dtypes.float32)
+    c = a*b
+    print(c.numpy())
+    assert c.dtype == dtypes.float32
+    np.testing.assert_allclose(c.numpy(), [1,4,9,16])
+
+  def test_half_matmul_upcast(self):
+    a = Tensor([[1,2],[3,4]], dtype=dtypes.float16)
+    b = Tensor.eye(2, dtype=dtypes.float32)
+    c = a@b
+    print(c.numpy())
+    assert c.dtype == dtypes.float32
+    np.testing.assert_allclose(c.numpy(), [[1,2],[3,4]])
+
 if __name__ == '__main__':
   unittest.main()
```
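Taken together, the new tests pin down the promotion rule: ops whose inputs share a dtype keep that dtype, while mixed float16/float32 inputs upcast to float32. A condensed sketch of that rule follows; the import paths are an assumption and may differ across tinygrad versions:

```python
# Sketch of the promotion rule the new tests assert.
# ASSUMPTION: Tensor and dtypes import paths match the tinygrad of this era.
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

half = Tensor([1, 2, 3, 4], dtype=dtypes.float16)
single = Tensor([1, 2, 3, 4], dtype=dtypes.float32)

assert (half * half).dtype == dtypes.float16    # same dtype in -> same dtype out
assert (half * single).dtype == dtypes.float32  # mixed inputs upcast to float32
```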
```diff
@@ -11,7 +11,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
-  FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(a.dtype), lambda x: x.stride(), lambda x,s: x.expand(s)),
+  FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
   MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg)
 }}
```
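This single-line MULACC change is the "Fix promotion on torch" item: the einsum already accumulates in float32, but the result was cast back to `a.dtype`, so a float16 @ float32 matmul came out float16 instead of float32. `torch.promote_types` picks the wider of the two operand dtypes, matching what the new `test_half_matmul_upcast` expects. A standalone before/after sketch using only stock PyTorch:

```python
import torch

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float16)
b = torch.eye(2, dtype=torch.float32)

# Old behavior: cast the einsum result back to a.dtype -> float16,
# losing the upcast a mixed-dtype matmul should produce.
old = torch.einsum("ij,jk->ik", a.float(), b.float()).type(a.dtype)
assert old.dtype == torch.float16

# New behavior: promote to the wider dtype of the two operands.
new = torch.einsum("ij,jk->ik", a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype))
assert new.dtype == torch.float32
```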