mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
AutoCastType tests for fp8s/bf16 (#12084)
This commit is contained in:
@@ -439,8 +439,10 @@ class TestAutoCastType(unittest.TestCase):
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
|
||||
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
|
||||
|
||||
@@ -471,8 +473,10 @@ class TestAutoCastType(unittest.TestCase):
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
|
||||
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
|
||||
|
||||
@@ -486,8 +490,10 @@ class TestAutoCastType(unittest.TestCase):
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
|
||||
#assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
|
||||
|
||||
|
||||
Reference in New Issue
Block a user