From 2fb8a7f60fdb3fa15241a57d22b2550bcd275d63 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Tue, 10 Mar 2026 20:51:19 -0700 Subject: [PATCH] fix test_invalid_tensor when before values are nan (#15215) --- test/unit/test_invalid_tensor.py | 88 +++++++++++--------------------- 1 file changed, 31 insertions(+), 57 deletions(-) diff --git a/test/unit/test_invalid_tensor.py b/test/unit/test_invalid_tensor.py index 05d50f10ce..5b64cd5dc1 100644 --- a/test/unit/test_invalid_tensor.py +++ b/test/unit/test_invalid_tensor.py @@ -1,4 +1,4 @@ -import unittest +import math, unittest from unittest.mock import patch from tinygrad import Tensor from tinygrad.device import Device @@ -6,7 +6,7 @@ from tinygrad.dtype import Invalid, dtypes from tinygrad.helpers import unwrap_class_type class TestInvalidTensor(unittest.TestCase): - def _realize_and_capture(self, out): + def _invalid_test_helper(self, out, expected): before = None original_call = (runtime_cls:=unwrap_class_type(Device[Device.DEFAULT].runtime)).__call__ @@ -17,151 +17,125 @@ class TestInvalidTensor(unittest.TestCase): with patch.object(runtime_cls, '__call__', patched_call): ret = out.tolist() + for i,v in enumerate(expected): + if v is None: assert before[i] == ret[i] or (math.isnan(before[i]) and math.isnan(ret[i])) + else: assert ret[i] == v + return before, ret def test_where_x_invalid(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_where_invalid_x(self): mask = Tensor.arange(4) < 2 out = mask.where(Invalid, Tensor([1.0, 2.0, 3.0, 4.0])) - before, ret = self._realize_and_capture(out) - assert ret[2] == 3.0 and ret[3] == 4.0 - assert before[0] == ret[0] and before[1] == ret[1] + self._invalid_test_helper(out, [None, None, 3.0, 4.0]) def test_where_invalid_2d(self): mask = Tensor.arange(6).reshape(2, 3) < 3 vals = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) out = mask.where(vals, Invalid) - before, ret = self._realize_and_capture(out) + before, ret = self._invalid_test_helper(out, []) assert ret[0] == [1.0, 2.0, 3.0] - assert before[3] == ret[1][0] and before[4] == ret[1][1] and before[5] == ret[1][2] + assert before[3] == ret[1][0] or (math.isnan(before[3]) and math.isnan(ret[1][0])) + assert before[4] == ret[1][1] or (math.isnan(before[4]) and math.isnan(ret[1][1])) + assert before[5] == ret[1][2] or (math.isnan(before[5]) and math.isnan(ret[1][2])) def test_where_invalid_int(self): mask = Tensor.arange(3) < 2 out = mask.where(Tensor([10, 20, 30]), Invalid) - before, ret = self._realize_and_capture(out) - assert ret[0] == 10 and ret[1] == 20 - assert before[2] == ret[2] + self._invalid_test_helper(out, [10, 20, None]) def test_where_invalid_add(self): mask = Tensor.arange(3) < 2 mixed = mask.where(Tensor([10.0, 20.0, 30.0]), Invalid) out = mixed + Tensor([1.0, 2.0, 3.0]) - before, ret = self._realize_and_capture(out) - assert ret[0] == 11.0 and ret[1] == 22.0 - assert before[2] == ret[2] + self._invalid_test_helper(out, [11.0, 22.0, None]) def test_where_invalid_add_left(self): mask = Tensor.arange(3) < 2 mixed = mask.where(Tensor([10.0, 20.0, 30.0]), Invalid) out = Tensor([1.0, 2.0, 3.0]) + mixed - before, ret = self._realize_and_capture(out) - assert ret[0] == 11.0 and ret[1] == 22.0 - assert before[2] == ret[2] + self._invalid_test_helper(out, [11.0, 22.0, None]) def test_where_always_true(self): mask = Tensor.arange(3) < 10 out = mask.where(Tensor([10.0, 20.0, 30.0]), Invalid) - before, ret = self._realize_and_capture(out) - assert ret == [10.0, 20.0, 30.0] + self._invalid_test_helper(out, [10.0, 20.0, 30.0]) def test_where_cast(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).cast(dtypes.int) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1 and ret[1] == 2 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1, 2, None, None]) def test_where_compare(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid) > 1 - before, ret = self._realize_and_capture(out) - assert not ret[0] and ret[1] - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [False, True, None, None]) def test_where_unary(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 4.0, 9.0, 16.0]), Invalid).sqrt() - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_where_where(self): mask1 = Tensor.arange(4) < 2 mask2 = Tensor.arange(4) > 0 out = mask2.where(mask1.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid), Invalid) - before, ret = self._realize_and_capture(out) - assert ret[1] == 2.0 - assert before[0] == ret[0] and before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [None, 2.0, None, None]) def test_where_reduce_always_true(self): mask = Tensor.arange(4) < 9 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).sum() - before, ret = self._realize_and_capture(out) + before, ret = self._invalid_test_helper(out, []) assert ret == 10.0 def test_invalid_unary(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Tensor.full((4,), Invalid, dtype=dtypes.float).sqrt()) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_invalid_binary(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Tensor.full((4,), Invalid, dtype=dtypes.float) + 2) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_invalid_binary_left(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), 2 + Tensor.full((4,), Invalid, dtype=dtypes.float)) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_invalid_reshape(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).reshape(2,2) - before, ret = self._realize_and_capture(out) + before, ret = self._invalid_test_helper(out, []) assert ret[0] == [1.0, 2.0] - assert ret[1] == [before[2], before[3]] + assert ret[1][0] == before[2] or (math.isnan(ret[1][0]) and math.isnan(before[2])) + assert ret[1][1] == before[3] or (math.isnan(ret[1][1]) and math.isnan(before[3])) def test_invalid_cast(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Tensor.full((4,), Invalid, dtype=dtypes.int).cast(dtypes.float)) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_invalid_bitcast(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Tensor.full((4,), Invalid, dtype=dtypes.int).bitcast(dtypes.float)) - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) def test_where_bitcast(self): mask = Tensor.arange(4) < 2 out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Tensor.full((4,), Invalid, dtype=dtypes.int)).bitcast(dtypes.int) - before, ret = self._realize_and_capture(out) - assert ret[0] == 0x3f800000 and ret[1] == 0x40000000 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [0x3f800000, 0x40000000, None, None]) # tensor indexing uses reduce, so the entire result becomes invalid @unittest.expectedFailure def test_tensor_index(self): idx = (Tensor.arange(4) < 2).where(Tensor([0, 1, 2, 3]), Invalid) out = Tensor([1.0, 2.0, 3.0, 4.0])[idx] - before, ret = self._realize_and_capture(out) - assert ret[0] == 1.0 and ret[1] == 2.0 - assert before[2] == ret[2] and before[3] == ret[3] + self._invalid_test_helper(out, [1.0, 2.0, None, None]) if __name__ == '__main__': unittest.main()