mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
delete irrelevant JIT regression test (#4024)
This commit is contained in:
@@ -237,62 +237,6 @@ class TestJit(unittest.TestCase):
|
||||
[0., 2., 3., 1., 0.]]
|
||||
np.testing.assert_allclose(want, Y)
|
||||
|
||||
@unittest.skip("was this supposed to work?")
|
||||
def test_jitted_read_assign(self):
|
||||
class Cache:
|
||||
def __init__(self):
|
||||
self.good_cache = Tensor.zeros(1)
|
||||
self.bad_cache = Tensor.zeros(1)
|
||||
self.good_jitted = TinyJit(self.good)
|
||||
self.bad_jitted = TinyJit(self.bad)
|
||||
|
||||
def good(self, y, cache_v=None):
|
||||
if cache_v is not None:
|
||||
self.good_cache.assign(cache_v+1-1).realize()
|
||||
return (self.good_cache + y).realize() # need + y to provide inputs to JIT
|
||||
|
||||
def bad(self, y, cache_v=None):
|
||||
if cache_v is not None:
|
||||
self.bad_cache.assign(cache_v).realize()
|
||||
return (self.bad_cache + y).realize()
|
||||
|
||||
cache = Cache()
|
||||
np.testing.assert_equal([0], cache.good_cache.numpy())
|
||||
np.testing.assert_equal([0], cache.bad_cache.numpy())
|
||||
|
||||
zero = Tensor([0.])
|
||||
one = Tensor([1.])
|
||||
two = Tensor([2.])
|
||||
|
||||
# save [1] in the caches
|
||||
cache.good(zero, one)
|
||||
cache.bad(zero, one)
|
||||
np.testing.assert_equal([1], cache.good_cache.numpy())
|
||||
np.testing.assert_equal([1], cache.bad_cache.numpy())
|
||||
|
||||
for i in range(5):
|
||||
x = Tensor([i*1.]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
|
||||
cache.good_jitted(x)
|
||||
cache.bad_jitted(x)
|
||||
|
||||
# verify the jitted calls read 1 from the cache
|
||||
np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
|
||||
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
|
||||
|
||||
# save [2] in the caches
|
||||
cache.good(zero, two)
|
||||
cache.bad(zero, two)
|
||||
np.testing.assert_equal([2], cache.good_cache.numpy())
|
||||
np.testing.assert_equal([2], cache.bad_cache.numpy())
|
||||
|
||||
# verify the jitted calls read 2 from the cache
|
||||
np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
|
||||
# but the bad_jitted doesn't!
|
||||
np.testing.assert_equal([1], cache.bad_jitted(zero).numpy())
|
||||
|
||||
assert_jit_cache_len(cache.good_jitted, 1)
|
||||
assert_jit_cache_len(cache.bad_jitted, 1)
|
||||
|
||||
def test_jit_buffer_behavior(self):
|
||||
@TinyJit
|
||||
def foo(x) -> Tensor: return x.sum().realize()
|
||||
|
||||
Reference in New Issue
Block a user