mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tests for cast there and back (#13195)
* fix cast folding in llama * dtypes that work everywhere * Skip test_cast_there_and_back for backend casts Skip test due to backend casting issues.
This commit is contained in:
@@ -14,6 +14,8 @@ from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
MOCKGPU = getenv("MOCKGPU")
|
||||
|
||||
from tinygrad.uop.ops import print_uops # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
|
||||
@@ -38,6 +40,22 @@ class TestLinearizer(unittest.TestCase):
|
||||
np.testing.assert_equal(a.numpy(), ta)
|
||||
np.testing.assert_equal(b.numpy(), tb)
|
||||
|
||||
@unittest.skip("TODO: some backends insert more casts")
|
||||
def test_cast_there_and_back(self):
|
||||
tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
|
||||
out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2
|
||||
ast = helper_linearizer_opt(out)
|
||||
uops = get_program(ast, opts=[]).uops
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_cast_back_and_there(self):
|
||||
tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
|
||||
out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2
|
||||
ast = helper_linearizer_opt(out)
|
||||
uops = get_program(ast, opts=[]).uops
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx")
|
||||
def test_late_bias_load(self):
|
||||
img = Tensor.empty(1, 3, 16, 16)
|
||||
@@ -491,7 +509,7 @@ def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
|
||||
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
|
||||
|
||||
def reset_bufs(bufs:list[Buffer]):
|
||||
for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
||||
for buf in bufs: buf.copyin(np.zeros((buf.size*buf.dtype.itemsize,), dtype=np.uint8).data)
|
||||
|
||||
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
|
||||
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):
|
||||
|
||||
Reference in New Issue
Block a user