mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
dtype promotion helpers (#2724)
* dtype promotion helpers * better tests * space
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
|
||||
from tinygrad import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from typing import Any, List
|
||||
@@ -234,5 +234,37 @@ class TestTypeSpec(unittest.TestCase):
|
||||
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
|
||||
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
|
||||
|
||||
# TODO: better way to write a set of core dtypes?
|
||||
core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(st.sampled_from(core_types))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
assert least_upper_dtype(dtype) == dtype
|
||||
assert least_upper_dtype(dtype, dtype) == dtype
|
||||
assert least_upper_dtype(dtype, dtype, dtype) == dtype
|
||||
|
||||
@given(st.sampled_from(core_types), st.sampled_from(core_types))
|
||||
def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
|
||||
result = least_upper_dtype(dtype1, dtype2)
|
||||
assert result >= dtype1 and result >= dtype2
|
||||
|
||||
def test_dtype_promo(self):
|
||||
assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8
|
||||
assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
|
||||
assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16
|
||||
assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32
|
||||
assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32
|
||||
assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
|
||||
assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
|
||||
# special!
|
||||
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float_scalar
|
||||
assert least_upper_dtype(dtypes.float_scalar, dtypes.float16) == dtypes.float16
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
|
||||
assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
|
||||
|
||||
assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32
|
||||
assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user