No pyint again (#7156)

* Revert "bring back pyint (#7150)"

This reverts commit 37e83ca6fc.

* remove truncate in const folding

* truncate_output=False
This commit is contained in:
chenyu
2024-10-19 13:48:59 -04:00
committed by GitHub
parent 30989fb459
commit f511ad9103
9 changed files with 19 additions and 26 deletions

View File

@@ -29,7 +29,6 @@ def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.pyint and device != "PYTHON": return False
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))

View File

@@ -14,7 +14,7 @@ pytestmark = pytest.mark.filterwarnings("ignore")
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint'])
core_dtypes = list(DTYPES_DICT.values())
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
@@ -22,7 +22,7 @@ dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_sup
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
if not is_dtype_supported(dtype): return []
# dont cast internal dtypes
return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint']
return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
def _test_to_np(a:Tensor, np_dtype, target):
if DEBUG >= 2: print(a)
@@ -806,7 +806,7 @@ class TestTensorMethod(unittest.TestCase):
class TestDtypeUsage(unittest.TestCase):
def test_max_w_alu(self):
for d in dtype_ints:
for d in dtypes.ints:
t = Tensor([[1, 2], [3, 4]], dtype=d)
(t*t).max().item()

View File

@@ -1315,7 +1315,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD", "METAL"])
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"])
if __name__ == '__main__':
unittest.main()

View File

@@ -230,6 +230,9 @@ class TestExecALU(TestUOps):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
# test no truncate
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250), truncate_output=False), 500)
class TestConstantFolding(unittest.TestCase):
def test_cast_const(self):
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)