delete irrelevant JIT regression test (#4024)

This commit is contained in:
mmmkkaaayy
2024-03-31 16:35:35 -07:00
committed by GitHub
parent 23c912e338
commit a4ae9352bd

View File

@@ -237,62 +237,6 @@ class TestJit(unittest.TestCase):
[0., 2., 3., 1., 0.]]
np.testing.assert_allclose(want, Y)
@unittest.skip("was this supposed to work?")
def test_jitted_read_assign(self):
class Cache:
def __init__(self):
self.good_cache = Tensor.zeros(1)
self.bad_cache = Tensor.zeros(1)
self.good_jitted = TinyJit(self.good)
self.bad_jitted = TinyJit(self.bad)
def good(self, y, cache_v=None):
if cache_v is not None:
self.good_cache.assign(cache_v+1-1).realize()
return (self.good_cache + y).realize() # need + y to provide inputs to JIT
def bad(self, y, cache_v=None):
if cache_v is not None:
self.bad_cache.assign(cache_v).realize()
return (self.bad_cache + y).realize()
cache = Cache()
np.testing.assert_equal([0], cache.good_cache.numpy())
np.testing.assert_equal([0], cache.bad_cache.numpy())
zero = Tensor([0.])
one = Tensor([1.])
two = Tensor([2.])
# save [1] in the caches
cache.good(zero, one)
cache.bad(zero, one)
np.testing.assert_equal([1], cache.good_cache.numpy())
np.testing.assert_equal([1], cache.bad_cache.numpy())
for i in range(5):
x = Tensor([i*1.]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
cache.good_jitted(x)
cache.bad_jitted(x)
# verify the jitted calls read 1 from the cache
np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
# save [2] in the caches
cache.good(zero, two)
cache.bad(zero, two)
np.testing.assert_equal([2], cache.good_cache.numpy())
np.testing.assert_equal([2], cache.bad_cache.numpy())
# verify the jitted calls read 2 from the cache
np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
# but the bad_jitted doesn't!
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
assert_jit_cache_len(cache.good_jitted, 1)
assert_jit_cache_len(cache.bad_jitted, 1)
def test_jit_buffer_behavior(self):
@TinyJit
def foo(x) -> Tensor: return x.sum().realize()