enable few tests for PTX test_dtype (#12445)

This commit is contained in:
chenyu
2025-10-03 20:56:30 +08:00
committed by GitHub
parent 0b534f71c2
commit 9f2b69b870

View File

@@ -7,7 +7,7 @@ from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad import Device, Tensor, dtypes
from hypothesis import assume, given, settings, strategies as strat
from hypothesis import given, settings, strategies as strat
from test.helpers import rand_for_dtype
from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX
import pytest
@@ -52,8 +52,6 @@ def _test_cast(a:Tensor, target_dtype:DType):
if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
_test_op(lambda: a.cast(target_dtype), target_dtype, expected)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)
@@ -294,7 +292,6 @@ class TestInt8DType(TestDType):
def test_int8_to_uint16_negative(self):
_test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4])
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken in ptx")
def test_bitcast_alt(self):
a = Tensor([72, -90, 27, 40, -53, 70, 96, 51], dtype=dtypes.int8).bitcast(dtypes.short)
self.assertListEqual(a.tolist(), [-22968, 10267, 18123, 13152])
@@ -308,8 +305,6 @@ class TestUint8DType(TestDType):
class TestBitCast(unittest.TestCase):
@given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
def test_shape_change_bitcast(self, dt1, dt2):
# NOTE: this has to be assume to prevent hypothesis from skipping all samples
assume(not (isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
if dt2 in dtypes.fp8s: