hotfix: compare test_var_multireduce against numpy (#5394)

This commit is contained in:
qazal
2024-07-12 06:57:08 +08:00
committed by GitHub
parent b91a0ccdc3
commit 0421f5d83e

View File

@@ -117,10 +117,11 @@ class TestLinearizer(unittest.TestCase):
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
helper_linearizer_ast((store, ), [x])
wanna_output = x.numpy().var(axis=2, ddof=0)
helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
# tinygrad ref
y_tiny = x.var(axis=2, correction=0)
np.testing.assert_allclose(y_tiny.numpy(), x.numpy().var(axis=2, ddof=0), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")