diff --git a/test/unit/test_dtype_spec.py b/test/unit/test_dtype_spec.py index 63990f7fd3..175edf851a 100644 --- a/test/unit/test_dtype_spec.py +++ b/test/unit/test_dtype_spec.py @@ -378,7 +378,7 @@ class TestTypePromotion(unittest.TestCase): assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64 assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64 # similar to jax but we don't use weak type - assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float16 + assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.fp8e4m3 assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32 assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64 @@ -387,6 +387,14 @@ class TestTypePromotion(unittest.TestCase): assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16 assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16 assert least_upper_dtype(dtypes.fp8e4m3, dtypes.fp8e5m2) == dtypes.half + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.bfloat16) == dtypes.bfloat16 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.bfloat16) == dtypes.bfloat16 + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.float16) == dtypes.float16 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.float16) == dtypes.float16 + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.int64) == dtypes.fp8e4m3 + assert least_upper_dtype(dtypes.fp8e4m3, dtypes.uint64) == dtypes.fp8e4m3 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.int64) == dtypes.fp8e5m2 + assert least_upper_dtype(dtypes.fp8e5m2, dtypes.uint64) == dtypes.fp8e5m2 class TestAutoCastType(unittest.TestCase): def setUp(self): diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 37fe4608f6..b94bd3f6a4 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -177,8 +177,8 @@ def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html # we don't support weak type and complex type promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], - dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], - dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16], + dtypes.int64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], + dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.fp8e4m3, dtypes.fp8e5m2], dtypes.fp8e5m2: [dtypes.float16, dtypes.bfloat16], dtypes.fp8e4m3: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }