dtype promotion helpers (#2724)

* dtype promotion helpers

* better tests

* space
This commit is contained in:
chenyu
2023-12-11 23:14:23 -05:00
committed by GitHub
parent 0232db294d
commit ef6e942a23
2 changed files with 48 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
@@ -234,5 +234,37 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
# TODO: better way to write a set of core dtypes?
core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
class TestTypePromotion(unittest.TestCase):
@given(st.sampled_from(core_types))
def test_self_promo_to_self(self, dtype):
assert least_upper_dtype(dtype) == dtype
assert least_upper_dtype(dtype, dtype) == dtype
assert least_upper_dtype(dtype, dtype, dtype) == dtype
@given(st.sampled_from(core_types), st.sampled_from(core_types))
def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
result = least_upper_dtype(dtype1, dtype2)
assert result >= dtype1 and result >= dtype2
def test_dtype_promo(self):
assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8
assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16
assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32
assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32
assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
# special!
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float_scalar
assert least_upper_dtype(dtypes.float_scalar, dtypes.float16) == dtypes.float16
assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32
assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64
if __name__ == '__main__':
unittest.main()