diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 8db737e2b8..183974ee1b 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -14,6 +14,8 @@ from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.cstyle import CUDARenderer MOCKGPU = getenv("MOCKGPU") +from tinygrad.uop.ops import print_uops # noqa: F401 # pylint: disable=unused-import + class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally @@ -38,6 +40,22 @@ class TestLinearizer(unittest.TestCase): np.testing.assert_equal(a.numpy(), ta) np.testing.assert_equal(b.numpy(), tb) + @unittest.skip("TODO: some backends insert more casts") + def test_cast_there_and_back(self): + tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize() + out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2 + ast = helper_linearizer_opt(out) + uops = get_program(ast, opts=[]).uops + self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1) + + @unittest.expectedFailure + def test_cast_back_and_there(self): + tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize() + out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2 + ast = helper_linearizer_opt(out) + uops = get_program(ast, opts=[]).uops + self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0) + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx") def test_late_bias_load(self): img = Tensor.empty(1, 3, 16, 16) @@ -491,7 +509,7 @@ def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]: return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] def reset_bufs(bufs:list[Buffer]): - for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled + for buf in bufs: buf.copyin(np.zeros((buf.size*buf.dtype.itemsize,), dtype=np.uint8).data) def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[], apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):