From 4f3f55328bbca566ac8215ebe24d01c3f09f3a59 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 12 Mar 2026 09:35:20 +0800 Subject: [PATCH] do not patch on invalid tensor tests (#15226) * do not patch on invalid tensor tests * cleanup --- test/unit/test_invalid_tensor.py | 43 +++++++++++--------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/test/unit/test_invalid_tensor.py b/test/unit/test_invalid_tensor.py index 5b64cd5dc1..756a61d579 100644 --- a/test/unit/test_invalid_tensor.py +++ b/test/unit/test_invalid_tensor.py @@ -1,27 +1,20 @@ -import math, unittest -from unittest.mock import patch +import unittest from tinygrad import Tensor -from tinygrad.device import Device from tinygrad.dtype import Invalid, dtypes -from tinygrad.helpers import unwrap_class_type +from tinygrad.engine.realize import run_schedule class TestInvalidTensor(unittest.TestCase): def _invalid_test_helper(self, out, expected): - before = None - original_call = (runtime_cls:=unwrap_class_type(Device[Device.DEFAULT].runtime)).__call__ + sched = out.schedule() + buf = out.uop.buffer + buf.allocate() + sentinel = memoryview(bytearray(b'\x42' * buf.nbytes)) + buf.copyin(sentinel) + before = buf.as_memoryview().cast(out.dtype.fmt).tolist() + run_schedule(sched) + ret = buf.as_memoryview().cast(out.dtype.fmt).tolist() - def patched_call(self_prg, *bufs, **kwargs): - nonlocal before - before = Device[Device.DEFAULT].allocator._as_buffer(bufs[0]).cast(out.dtype.fmt).tolist() - return original_call(self_prg, *bufs, **kwargs) - - with patch.object(runtime_cls, '__call__', patched_call): ret = out.tolist() - - for i,v in enumerate(expected): - if v is None: assert before[i] == ret[i] or (math.isnan(before[i]) and math.isnan(ret[i])) - else: assert ret[i] == v - - return before, ret + for i,v in enumerate(expected): self.assertEqual(ret[i], before[i] if v is None else v) def test_where_x_invalid(self): mask = Tensor.arange(4) < 2 @@ -37,11 +30,7 @@ class TestInvalidTensor(unittest.TestCase): mask = Tensor.arange(6).reshape(2, 3) < 3 vals = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) out = mask.where(vals, Invalid) - before, ret = self._invalid_test_helper(out, []) - assert ret[0] == [1.0, 2.0, 3.0] - assert before[3] == ret[1][0] or (math.isnan(before[3]) and math.isnan(ret[1][0])) - assert before[4] == ret[1][1] or (math.isnan(before[4]) and math.isnan(ret[1][1])) - assert before[5] == ret[1][2] or (math.isnan(before[5]) and math.isnan(ret[1][2])) + self._invalid_test_helper(out, [1.0, 2.0, 3.0, None, None, None]) def test_where_invalid_int(self): mask = Tensor.arange(3) < 2 @@ -89,8 +78,7 @@ class TestInvalidTensor(unittest.TestCase): def test_where_reduce_always_true(self): mask = Tensor.arange(4) < 9 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).sum() - before, ret = self._invalid_test_helper(out, []) - assert ret == 10.0 + self._invalid_test_helper(out, [10.0]) def test_invalid_unary(self): mask = Tensor.arange(4) < 2 @@ -110,10 +98,7 @@ class TestInvalidTensor(unittest.TestCase): def test_invalid_reshape(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).reshape(2,2) - before, ret = self._invalid_test_helper(out, []) - assert ret[0] == [1.0, 2.0] - assert ret[1][0] == before[2] or (math.isnan(ret[1][0]) and math.isnan(before[2])) - assert ret[1][1] == before[3] or (math.isnan(ret[1][1]) and math.isnan(before[3])) + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_invalid_cast(self): mask = Tensor.arange(4) < 2