do not use getenv('PTX') in tests (#12095)

* test without ptx

* fix tests

* fix test

* linters
This commit is contained in:
nimlgen
2025-09-10 14:04:07 +03:00
committed by GitHub
parent 0e420e68b4
commit 551560b87c
9 changed files with 30 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad import Device, Tensor, dtypes
from hypothesis import assume, given, settings, strategies as strat
from test.helpers import rand_for_dtype
@@ -49,7 +50,7 @@ def _test_cast(a:Tensor, target_dtype:DType):
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype))
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected.tolist())
@@ -100,7 +101,7 @@ class TestDType(unittest.TestCase):
))
@unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
@unittest.skipIf(getenv("PTX"), "skip for now")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "skip for now")
def test_uint_overflow(self):
if not dtypes.is_unsigned(self.DTYPE): raise unittest.SkipTest("only for unsigned")
v = dtypes.max(self.DTYPE)
@@ -255,7 +256,8 @@ class TestFloatDType(TestDType):
class TestDoubleDType(TestDType):
DTYPE = dtypes.double
@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or getenv("PTX"), "conversion not supported on CI CUDA and PTX") # TODO: why not?
@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \
isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "conversion not supported on CI CUDA and PTX") # TODO: why not?
def test_float64_increased_precision(self):
for func in [
lambda t: t.exp(),
@@ -279,21 +281,21 @@ class TestDoubleDType(TestDType):
class TestInt8DType(TestDType):
DTYPE = dtypes.int8
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
@unittest.skipIf(getenv("CUDA",0)==1 or isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "cuda saturation works differently")
def test_int8_to_uint8_negative(self):
_test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])
def test_int8_to_uint16_negative(self):
_test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4])
@unittest.skipIf(getenv("PTX"), "broken in ptx")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken in ptx")
def test_bitcast_alt(self):
a = Tensor([72, -90, 27, 40, -53, 70, 96, 51], dtype=dtypes.int8).bitcast(dtypes.short)
self.assertListEqual(a.tolist(), [-22968, 10267, 18123, 13152])
class TestUint8DType(TestDType):
DTYPE = dtypes.uint8
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
@unittest.skipIf(getenv("CUDA",0)==1 or isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "cuda saturation works differently")
def test_uint8_to_int8_overflow(self):
_test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
@@ -301,7 +303,7 @@ class TestBitCast(unittest.TestCase):
@given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats))
def test_shape_change_bitcast(self, dt1, dt2):
# NOTE: this has to be assume to prevent hypothesis from skipping all samples
assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
assume(not (isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())