AutoCastType tests for fp8s/bf16 (#12084)

This commit is contained in:
b1tg
2025-09-09 23:33:01 +08:00
committed by GitHub
parent 5e76eff26d
commit 14faf7a5c0
2 changed files with 11 additions and 5 deletions

View File

@@ -439,8 +439,10 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
@@ -471,8 +473,10 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
@@ -486,8 +490,10 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64