multi reduce Tensor.var passing verify_lazyop (#5346)

* what about this

* reset late gate
This commit is contained in:
qazal
2024-07-09 17:20:17 +03:00
committed by GitHub
parent 3d452195e4
commit 1f5de80eba
4 changed files with 28 additions and 4 deletions

View File

@@ -101,6 +101,25 @@ class TestLinearizer(unittest.TestCase):
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
@unittest.skipIf(CI and Device.DEFAULT == "AMD", "remu doesn't have multiple wave syncs yet")
def test_var_multireduce(self):
Tensor.manual_seed(0)
x = Tensor.randn(3, 27, 32).realize()
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1)).expand((3, 27, 32, 32))))
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1)))) # noqa: E501
# store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
# verify_lazyop(store)
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
squares = (second_x-mean)*(second_x-mean)
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
helper_linearizer_ast((store, ), [x])
# tinygrad ref
y_tiny = x.var(axis=2, correction=0)
np.testing.assert_allclose(y_tiny.numpy(), x.numpy().var(axis=2, ddof=0), atol=1e-4, rtol=1e-4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_end_local(self):