# tinygrad/test/test_dtype.py
import contextlib, unittest, math
import numpy as np
import torch
from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI, EMULATED_DTYPES
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from tinygrad import Context, Device, Tensor, dtypes
from tinygrad.uop import Ops
from hypothesis import given, settings, strategies as strat
from test.helpers import rand_for_dtype
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
import pytest
pytestmark = pytest.mark.filterwarnings("ignore")
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
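# Helpers shared by the dtype tests below. get_available_cast_dtypes lists the dtypes a value of `dtype`
# can be cast to on the current device, skipping internal ("_"-prefixed) dtypes and anything unsupported.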
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
  # don't cast internal dtypes
  dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
  if not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
    if dtype in (dtypes.long, dtypes.ulong): return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support
    else: return []
  return dts
def _to_torch_storage_type(dtype:DType):
  if dtype == dtypes.bfloat16: return torch.float32
  if dtype in dtypes.fp8s: return torch.float32
  return _to_torch_dtype(dtype)
def _test_to_np(a:Tensor, np_dtype, target):
  if DEBUG >= 2: print(a)
  na = a.numpy()
  if DEBUG >= 2: print(na, na.dtype, a.uop.base.realized)
  try:
    assert na.dtype == np_dtype
    np.testing.assert_allclose(na, target)
  except AssertionError as e:
    raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e
def _test_op(fxn, target_dtype:DType, target):
  _assert_eq(fxn(), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType):
  if a.is_floating_point() and dtypes.is_unsigned(target_dtype):
    # converting negative float to unsigned integer is undefined
    a = a.abs()
  expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
  if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]
  _test_op(lambda: a.cast(target_dtype), target_dtype, expected)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
  expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
  if target_dtype in dtypes.fp8s: expected = [fp8_to_float(x, target_dtype) for x in expected]
  _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)
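# TestDType is the shared base class: each subclass pins DTYPE and inherits the numpy round-trip,
# cast, bitcast, and overflow tests below.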
class TestDType(unittest.TestCase):
  DTYPE: Any = None
  DATA: Any = None
  @classmethod
  def setUpClass(cls):
    if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
    cls.DATA = rand_for_dtype(cls.DTYPE, 10)
  def setUp(self):
    if self.DTYPE is None: raise unittest.SkipTest("base class")
  def test_to_np(self):
    _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
  def test_casts_to(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
  def test_casts_from(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)
  def test_same_size_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize == self.DTYPE.itemsize:
        _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
  def test_upcast_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize > self.DTYPE.itemsize:
        _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
  def test_upcast_to_ops(self):
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype.itemsize < self.DTYPE.itemsize:
        _test_ops(a_dtype=dtype, b_dtype=self.DTYPE)
  def test_bitcast(self):
    if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
    for dtype in get_available_cast_dtypes(self.DTYPE):
      if dtype != dtypes.bool:
        _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype)
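  # unsigned arithmetic wraps modulo 2**bits; numpy's fixed-width types wrap the same way,
  # so the reference values stay in sync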
  @unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
  def test_uint_overflow(self):
    if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
    v = dtypes.max(self.DTYPE)
    _test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
    _test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
  def test_dtypes_DTYPES_DICT(self):
    self.assertIn("float", DTYPES_DICT)
    self.assertIn("float32", DTYPES_DICT)
    self.assertEqual(len(DTYPES_DICT), 26)
    self.assertTrue(all(isinstance(value, DType) for value in DTYPES_DICT.values()))
    self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in DTYPES_DICT.values() if _to_np_dtype(value) is not None))
  def test_resulting_and_init_dtypes_match(self):
    dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
    data = [1., 2., 0., 0.5, -1.5, 5.25]
    for dt in dtypes:
      arr = np.asarray(data).astype(dt)
      tensor = Tensor(arr)
      if not is_dtype_supported(tensor.dtype): continue
      tin = tensor.numpy()
      tor = torch.as_tensor(arr).detach().numpy()
      assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
      np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
  def test_finfo(self):
    if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
    info = np.finfo(_to_np_dtype(self.DTYPE))
    self.assertEqual(info.bits, self.DTYPE.bitsize)
    self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
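# _test_ops runs elementwise add/mul and matmul across two dtypes and checks that the result
# promotes to least_upper_dtype(a_dtype, b_dtype).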
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
  target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
  if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
  if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
  _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
  _assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2])
  _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
  _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
  _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
class TestFp8s(unittest.TestCase):
  def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3
  def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
class TestFp8sConversions(unittest.TestCase):
  @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
  def test_float_to_fp8e4m3(self, x):
    np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.float8_e4m3fn).view(torch.uint8).item())
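  # fp8e4m3 (torch's float8_e4m3fn) has no infinity: 0x7E/0xFE (126/254) is the largest finite magnitude (448.0),
  # and inf/nan collapse to the NaN encoding 0x7F/0xFF (127/255)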
  def test_float_to_fp8e4m3_extreme_values(self):
    np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX, dtypes.fp8e4m3), 126)
    np.testing.assert_equal(float_to_fp8(FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 126)
    np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e4m3), 127)
    np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX, dtypes.fp8e4m3), 254)
    np.testing.assert_equal(float_to_fp8(-FP8E4M3_MAX*1.01, dtypes.fp8e4m3), 254)
    np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e4m3), 255)
    np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e4m3), 127)
    np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e4m3), 255)
  @given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E5M2_MAX, max_value=FP8E5M2_MAX))
  def test_float_to_fp8e5m2(self, x):
    np.testing.assert_equal(float_to_fp8(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.float8_e5m2).view(torch.uint8).item())
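  # fp8e5m2 does encode infinity: finite overflow saturates to the max finite 0x7B/0xFB (57344.0),
  # inf maps to 0x7C/0xFC and nan to 0x7E/0xFE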
  def test_float_to_fp8e5m2_extreme_values(self):
    np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX, dtypes.fp8e5m2), 123)
    np.testing.assert_equal(float_to_fp8(FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 123)
    np.testing.assert_equal(float_to_fp8(math.inf, dtypes.fp8e5m2), 124)
    np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX, dtypes.fp8e5m2), 251)
    np.testing.assert_equal(float_to_fp8(-FP8E5M2_MAX*1.01, dtypes.fp8e5m2), 251)
    np.testing.assert_equal(float_to_fp8(-math.inf, dtypes.fp8e5m2), 252)
    np.testing.assert_equal(float_to_fp8(math.nan, dtypes.fp8e5m2), 126)
    np.testing.assert_equal(float_to_fp8(-math.nan, dtypes.fp8e5m2), 254)
  @given(strat.integers(min_value=0, max_value=255))
  def test_fp8e4m3_to_float(self, x):
    np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e4m3), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e4m3fn).float().item())
  @given(strat.integers(min_value=0, max_value=255))
  def test_fp8e5m2_to_float(self, x):
    np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2).float().item())
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16(unittest.TestCase):
  def test_bf16_creation_numpy(self):
    data = [-1, 1, 2]
    t = Tensor(data, dtype=dtypes.bfloat16)
    assert t.dtype == dtypes.bfloat16
    tnp = t.numpy()
    assert tnp.dtype == np.float32
    np.testing.assert_allclose(tnp, np.array(data))
  def test_bf16_ones(self):
    t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
    assert t.dtype == dtypes.bfloat16
    np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
  def test_bf16_eye(self):
    t = Tensor.eye(3, dtype=dtypes.bfloat16)
    assert t.dtype == dtypes.bfloat16
    np.testing.assert_allclose(t.numpy(), np.eye(3))
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
class TestBFloat16DType(unittest.TestCase):
  def test_bf16_to_float(self):
    _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
  def test_float_to_bf16(self):
    _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
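  # bfloat16 keeps only 8 significand bits, so 10000 (0x461C4000 as float32) rounds to 9984 (0x461C)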
  def test_bf16(self):
    t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
    t.realize()
    back = t.cast(dtypes.float32)
    assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) and is_dtype_supported(dtypes.float16), "bfloat16 or float16 not supported")
class TestBFloat16DTypeCast(unittest.TestCase):
  def test_f16_to_bf16_conversion(self):
    original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
    converted_tensor = original_tensor.cast(dtypes.bfloat16)
    self.assertEqual(converted_tensor.dtype, dtypes.bfloat16)
    back_to_float32 = converted_tensor.cast(dtypes.float32)
    original_to_float32 = original_tensor.cast(dtypes.float32)
    np.testing.assert_allclose(back_to_float32.numpy(), original_to_float32.numpy(), rtol=1e-2, atol=1e-3)
  def test_f16_to_bf16_edge_cases(self):
    edge_cases = Tensor([0.0, -0.0, float('inf'), float('-inf'), float('nan')], dtype=dtypes.float16)
    converted = edge_cases.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_equal(converted.numpy(), edge_cases.cast(dtypes.float32).numpy())
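  # bfloat16 shares float32's 8-bit exponent, so every float16 magnitude (up to 65504) is representable;
  # only mantissa precision is lost in the conversion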
  def test_f16_to_bf16_range_precision(self):
    large_value = Tensor([65504.0], dtype=dtypes.float16) # Max representable in float16
    small_value = Tensor([6.1035e-5], dtype=dtypes.float16) # Smallest positive normal float16
    large_converted = large_value.cast(dtypes.bfloat16).cast(dtypes.float32)
    small_converted = small_value.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_allclose(large_converted.numpy(), large_value.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
    np.testing.assert_equal(small_converted.numpy(), small_value.cast(dtypes.float32).numpy())
  def test_f16_to_bf16_randomized(self):
    np.random.seed(42) # For reproducibility
    random_values = Tensor(np.random.uniform(-65504, 65504, 1000), dtype=dtypes.float16)
    converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
    np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
class TestHalfDType(TestDType): DTYPE = dtypes.half
class TestFloatDType(TestDType):
  DTYPE = dtypes.float
  def test_float_to_uint(self):
    _test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32,
             [0, 0, 1])
class TestDoubleDType(TestDType):
  DTYPE = dtypes.double
  @unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \
                   isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "conversion not supported on CI CUDA, PTX, and NIR") # TODO: why not?
  def test_float64_increased_precision(self):
    for func in [
      lambda t: t.exp(),
      lambda t: t.exp2(),
      lambda t: t.log(),
      lambda t: t.log2(),
      lambda t: t.sqrt(),
      lambda t: t.rsqrt(),
      lambda t: t.sin(),
      lambda t: t.cos(),
      lambda t: t.tan(),
      lambda t: t.sigmoid(),
    ]:
      a = [2, 3, 4]
      np.testing.assert_allclose(func(Tensor(a, dtype=self.DTYPE)).numpy(), func(torch.tensor(a, dtype=torch.float64)), rtol=1e-12, atol=1e-12)
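  # 3.4e40 is above float32's largest finite value (~3.4e38), so the downcast must overflow to inf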
  def test_float64_to_float32_cast_inf(self):
    _test_op(lambda: Tensor([3.4e40, 3.4e38, 1, 0], dtype=dtypes.float64).cast(dtypes.float32),
             dtypes.float32, [float('inf'), 3.4e38, 1, 0])
class TestInt8DType(TestDType):
  DTYPE = dtypes.int8
  @unittest.skipIf(getenv("CUDA",0)==1 or isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "cuda saturation works differently")
  def test_int8_to_uint8_negative(self):
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
  def test_int8_to_uint16_negative(self):
    _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4])
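  # pairs of int8 bytes become one int16; on a little-endian host (72, -90) -> 72 + (-90 << 8) = -22968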
  def test_bitcast_alt(self):
    a = Tensor([72, -90, 27, 40, -53, 70, 96, 51], dtype=dtypes.int8).bitcast(dtypes.short)
    self.assertListEqual(a.tolist(), [-22968, 10267, 18123, 13152])
class TestUint8DType(TestDType):
  DTYPE = dtypes.uint8
  @unittest.skipIf(getenv("CUDA",0)==1 or isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "cuda saturation works differently")
  def test_uint8_to_int8_overflow(self):
    _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
class TestBitCast(unittest.TestCase):
  @given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
  def test_shape_change_bitcast(self, dt1, dt2):
    data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
    expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
    if dt2 in dtypes.fp8s:
      expected = torch.tensor([fp8_to_float(x, dt2) for x in expected.view(-1).tolist()]).view_as(expected)
    _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
  def test_shape_change_bitcast_exceptions(self):
    with self.assertRaises(RuntimeError):
      # should fail because 3 int8s are 3 bytes, a float16 is two bytes, and 3 isn't a multiple of 2
      Tensor.empty((3,), dtype=dtypes.int8).bitcast(dtypes.float16)
    with self.assertRaises(RuntimeError):
      # should fail because backprop through bitcast is undefined
      Tensor.empty((4,), dtype=dtypes.int8, requires_grad=True).bitcast(dtypes.float16)
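  # 1.0f is 0x3f800000 in IEEE-754 single precision, so bitcasting float32 ones yields that integer pattern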
  def test_bitcast_float_to_int32(self):
    a = Tensor([1.,2,3])
    b = a.bitcast(dtypes.int32)
    assert b.numpy()[0] == 0x3f800000
  def test_bitcast_upcasted(self):
    a = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000
    b = a.bitcast(dtypes.float32)
    assert b.numpy()[0,0] == 1.
class TestInt16DType(TestDType): DTYPE = dtypes.int16
class TestUint16DType(TestDType):
  DTYPE = dtypes.uint16
  def test_uint16_to_int8_overflow(self):
    _test_op(lambda: Tensor([2**16-1, 2**16-2, 1, 0], dtype=dtypes.uint16).cast(dtypes.int8), dtypes.int8, [-1, -2, 1, 0])
class TestInt32DType(TestDType): DTYPE = dtypes.int32
class TestUint32DType(TestDType): DTYPE = dtypes.uint32
class TestInt64DType(TestDType): DTYPE = dtypes.int64
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
class TestEmulatedInt64DType(TestInt64DType):
  @classmethod
  def setUpClass(cls):
    cls.stack = contextlib.ExitStack()
    cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
    cls.DATA = rand_for_dtype(cls.DTYPE, 10)
  @classmethod
  def tearDownClass(cls): cls.stack.close()
class TestUint64DType(TestDType):
  DTYPE = dtypes.uint64
  def test_uint64_load(self):
    assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
@unittest.skipUnless(Ops.SHL in Device[Device.DEFAULT].renderer.code_for_op, "long decomp requires bitshift")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX does indexing math with longs")
class TestEmulatedUInt64DType(TestUint64DType):
  @classmethod
  def setUpClass(cls):
    cls.stack = contextlib.ExitStack()
    cls.stack.enter_context(Context(EMULATED_DTYPES="long"))
    cls.DATA = rand_for_dtype(cls.DTYPE, 10)
  @classmethod
  def tearDownClass(cls): cls.stack.close()
class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16
class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3
class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2
class TestPtrDType(unittest.TestCase):
  def test_vec_double(self):
    dt1 = dtypes.float.vec(4).ptr().vec(4)
    dt2 = dtypes.float.vec(4).ptr().vec(4)
    self.assertEqual(dt1, dt2)
    self.assertEqual(str(dt1), str(dt2))
  def test_scalar(self):
    dt = dtypes.float.vec(4).ptr().scalar()
    self.assertEqual(dt.base, dtypes.float.vec(4))
    dt = dtypes.float.vec(4).ptr().vec(4).scalar()
    self.assertEqual(dt.base, dtypes.float.vec(4))
    dt = dtypes.float.vec(4).scalar()
    self.assertEqual(dt, dtypes.float)
  def test_serialize(self):
    dt = dtypes.float.vec(4).ptr().vec(4)
    self.assertEqual(dt, eval(str(dt)))
  def test_vec_ptr_sz(self):
    dt = dtypes.float.ptr(1024).vec(4)
    self.assertEqual(dt, eval(str(dt)))
    self.assertEqual(str(dt), "dtypes.float.ptr(1024).vec(4)")
  def test_vcount(self):
    dt = dtypes.float.ptr().vec(4)
    self.assertEqual(dt.vcount, 4)
    self.assertEqual(dt.v, 4)
    self.assertEqual(dt.count, 1)
    dt = dtypes.float.vec(4).ptr()
    self.assertEqual(dt.vcount, 1)
    self.assertEqual(dt.v, 1)
    self.assertEqual(dt.count, 4)
    dt = dtypes.float.vec(4).ptr().vec(4)
    self.assertEqual(dt.vcount, 4)
    self.assertEqual(dt.v, 4)
    self.assertEqual(dt.count, 4)
class TestImplicitFunctionTypeChange(unittest.TestCase):
  def test_functions(self):
    result = []
    for func in [
      lambda t: t.exp(),
      lambda t: t.exp2(),
      lambda t: t.log(),
      lambda t: t.log2(),
      lambda t: t.sqrt(),
      lambda t: t.sin(),
    ]:
      t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0]))
      result.append(t.numpy().sum())
    assert all(result)
class TestTensorMethod(unittest.TestCase):
  @given(strat.sampled_from(core_dtypes))
  def test_abs_diff(self, dt):
    if dt == dtypes.bool or not is_dtype_supported(dt): return
    a, b = Tensor([2], dtype=dt), Tensor([1], dtype=dt)
    ret = (a - b).abs()
    np.testing.assert_allclose(ret.numpy(), np.abs(a.numpy()-b.numpy()))
class TestDtypeUsage(unittest.TestCase):
  def test_max_w_alu(self):
    for d in dtypes.ints:
      if is_dtype_supported(d):
        t = Tensor([[1, 2], [3, 4]], dtype=d)
        (t*t).max().item()
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
class TestOpsBFloat16(unittest.TestCase):
  def test_cast(self):
    # TODO: helper_test_op breaks in unrelated part
    data = [60000.0, 70000.0, 80000.0]
    np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
  # on some CPUs there is no native bfloat16 sqrt
  @unittest.skipIf(Device.DEFAULT == "CPU", "no approximation")
  def test_no_approximation(self):
    data = [326.0, 339.0, 10603200512.0]
    expected = torch.tensor(data, dtype=torch.bfloat16).sqrt().float().numpy()
    np.testing.assert_allclose(Tensor(data, dtype=dtypes.bfloat16).sqrt().numpy(), expected)
if __name__ == '__main__':
  unittest.main()