tests for cast there and back (#13195)

* fix cast folding in llama

* dtypes that work everywhere

* Skip test_cast_there_and_back for backend casts

Skip test due to backend casting issues.
This commit is contained in:
George Hotz
2025-11-14 16:56:09 -08:00
committed by GitHub
parent 6c5fa349e1
commit 567066f51f

View File

@@ -14,6 +14,8 @@ from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import CUDARenderer
MOCKGPU = getenv("MOCKGPU")
from tinygrad.uop.ops import print_uops # noqa: F401 # pylint: disable=unused-import
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
@@ -38,6 +40,22 @@ class TestLinearizer(unittest.TestCase):
np.testing.assert_equal(a.numpy(), ta)
np.testing.assert_equal(b.numpy(), tb)
@unittest.skip("TODO: some backends insert more casts")
def test_cast_there_and_back(self):
tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2
ast = helper_linearizer_opt(out)
uops = get_program(ast, opts=[]).uops
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
@unittest.expectedFailure
def test_cast_back_and_there(self):
tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2
ast = helper_linearizer_opt(out)
uops = get_program(ast, opts=[]).uops
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx")
def test_late_bias_load(self):
img = Tensor.empty(1, 3, 16, 16)
@@ -491,7 +509,7 @@ def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
def reset_bufs(bufs:list[Buffer]):
for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
for buf in bufs: buf.copyin(np.zeros((buf.size*buf.dtype.itemsize,), dtype=np.uint8).data)
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):