mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
do not patch on invalid tensor tests (#15226)
* do not patch on invalid tensor tests * cleanup
This commit is contained in:
@@ -1,27 +1,20 @@
|
||||
import math, unittest
|
||||
from unittest.mock import patch
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.dtype import Invalid, dtypes
|
||||
from tinygrad.helpers import unwrap_class_type
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
class TestInvalidTensor(unittest.TestCase):
|
||||
def _invalid_test_helper(self, out, expected):
|
||||
before = None
|
||||
original_call = (runtime_cls:=unwrap_class_type(Device[Device.DEFAULT].runtime)).__call__
|
||||
sched = out.schedule()
|
||||
buf = out.uop.buffer
|
||||
buf.allocate()
|
||||
sentinel = memoryview(bytearray(b'\x42' * buf.nbytes))
|
||||
buf.copyin(sentinel)
|
||||
before = buf.as_memoryview().cast(out.dtype.fmt).tolist()
|
||||
run_schedule(sched)
|
||||
ret = buf.as_memoryview().cast(out.dtype.fmt).tolist()
|
||||
|
||||
def patched_call(self_prg, *bufs, **kwargs):
|
||||
nonlocal before
|
||||
before = Device[Device.DEFAULT].allocator._as_buffer(bufs[0]).cast(out.dtype.fmt).tolist()
|
||||
return original_call(self_prg, *bufs, **kwargs)
|
||||
|
||||
with patch.object(runtime_cls, '__call__', patched_call): ret = out.tolist()
|
||||
|
||||
for i,v in enumerate(expected):
|
||||
if v is None: assert before[i] == ret[i] or (math.isnan(before[i]) and math.isnan(ret[i]))
|
||||
else: assert ret[i] == v
|
||||
|
||||
return before, ret
|
||||
for i,v in enumerate(expected): self.assertEqual(ret[i], before[i] if v is None else v)
|
||||
|
||||
def test_where_x_invalid(self):
|
||||
mask = Tensor.arange(4) < 2
|
||||
@@ -37,11 +30,7 @@ class TestInvalidTensor(unittest.TestCase):
|
||||
mask = Tensor.arange(6).reshape(2, 3) < 3
|
||||
vals = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
out = mask.where(vals, Invalid)
|
||||
before, ret = self._invalid_test_helper(out, [])
|
||||
assert ret[0] == [1.0, 2.0, 3.0]
|
||||
assert before[3] == ret[1][0] or (math.isnan(before[3]) and math.isnan(ret[1][0]))
|
||||
assert before[4] == ret[1][1] or (math.isnan(before[4]) and math.isnan(ret[1][1]))
|
||||
assert before[5] == ret[1][2] or (math.isnan(before[5]) and math.isnan(ret[1][2]))
|
||||
self._invalid_test_helper(out, [1.0, 2.0, 3.0, None, None, None])
|
||||
|
||||
def test_where_invalid_int(self):
|
||||
mask = Tensor.arange(3) < 2
|
||||
@@ -89,8 +78,7 @@ class TestInvalidTensor(unittest.TestCase):
|
||||
def test_where_reduce_always_true(self):
|
||||
mask = Tensor.arange(4) < 9
|
||||
out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).sum()
|
||||
before, ret = self._invalid_test_helper(out, [])
|
||||
assert ret == 10.0
|
||||
self._invalid_test_helper(out, [10.0])
|
||||
|
||||
def test_invalid_unary(self):
|
||||
mask = Tensor.arange(4) < 2
|
||||
@@ -110,10 +98,7 @@ class TestInvalidTensor(unittest.TestCase):
|
||||
def test_invalid_reshape(self):
|
||||
mask = Tensor.arange(4) < 2
|
||||
out = mask.where(Tensor([1.0, 2.0, 3.0, 4.0]), Invalid).reshape(2,2)
|
||||
before, ret = self._invalid_test_helper(out, [])
|
||||
assert ret[0] == [1.0, 2.0]
|
||||
assert ret[1][0] == before[2] or (math.isnan(ret[1][0]) and math.isnan(before[2]))
|
||||
assert ret[1][1] == before[3] or (math.isnan(ret[1][1]) and math.isnan(before[3]))
|
||||
self._invalid_test_helper(out, [1.0, 2.0, None, None])
|
||||
|
||||
def test_invalid_cast(self):
|
||||
mask = Tensor.arange(4) < 2
|
||||
|
||||
Reference in New Issue
Block a user