mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add challenge tests
This commit is contained in:
51
test/external/external_test_opt.py
vendored
51
test/external/external_test_opt.py
vendored
@@ -174,5 +174,56 @@ class TestOpt(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3)
|
||||
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
|
||||
|
||||
@unittest.skip("this is broken")
|
||||
def test_no_binop_rerun(self):
|
||||
a = Tensor.randn(16, 16)
|
||||
b = Tensor.randn(16, 16)
|
||||
with CLCache():
|
||||
c = a*b
|
||||
d = (a*b).reshape(16, 16, 1)
|
||||
c.realize()
|
||||
d.realize()
|
||||
assert len(GlobalCounters.cache) == 1, "binop was rerun!"
|
||||
np.testing.assert_allclose(c.numpy(), d.numpy(), rtol=1e-3)
|
||||
|
||||
@unittest.skip("this is broken")
|
||||
def test_no_binop_rerun_alt(self):
|
||||
a = Tensor.randn(16, 16)
|
||||
b = Tensor.randn(16, 16)
|
||||
with CLCache():
|
||||
c = (a*b).reshape(16, 16, 1)
|
||||
d = a*b
|
||||
c.realize()
|
||||
d.realize()
|
||||
assert len(GlobalCounters.cache) == 1, "binop was rerun!"
|
||||
np.testing.assert_allclose(c.numpy(), d.numpy(), rtol=1e-3)
|
||||
|
||||
# TODO: should be okay with PUSH_PERMUTES
|
||||
def test_no_reduceop_rerun(self):
|
||||
if PUSH_PERMUTES: return
|
||||
a = Tensor.randn(16, 16, 16)
|
||||
with CLCache():
|
||||
c = a.sum(2)
|
||||
d = a.sum(2).permute(1,0)
|
||||
c.realize()
|
||||
d.realize()
|
||||
cache_len = len(GlobalCounters.cache)
|
||||
np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy())
|
||||
assert cache_len == 1, "reduceop was rerun!"
|
||||
|
||||
# TODO: should be okay with PUSH_PERMUTES
|
||||
def test_no_reduceop_rerun_alt(self):
|
||||
if PUSH_PERMUTES: return
|
||||
a = Tensor.randn(16, 16, 16)
|
||||
with CLCache():
|
||||
c = a.sum(2).permute(1,0)
|
||||
d = a.sum(2)
|
||||
c.realize()
|
||||
d.realize()
|
||||
cache_len = len(GlobalCounters.cache)
|
||||
np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0))
|
||||
assert cache_len == 1, "reduceop was rerun!"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user