mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
multireduce scheduler tests (#5141)
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
@@ -687,6 +687,81 @@ class TestSchedule(unittest.TestCase):
|
||||
out1 = out0[0] + Tensor.empty(1, )
|
||||
check_schedule([r, out0, out1], 3)
|
||||
|
||||
# multireduce spec
|
||||
def test_std_multireduce_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.std(-1)
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# multireduce spec
|
||||
def test_argmin_multireduce_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.argmin(-1)
|
||||
run_schedule(check_schedule(out, 3))
|
||||
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
|
||||
|
||||
# multireduce spec
|
||||
def test_argmax_multireduce_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 32).realize()
|
||||
out = x.argmax(-1)
|
||||
run_schedule(check_schedule(out, 3))
|
||||
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
|
||||
|
||||
# multireduce spec
|
||||
def test_scaled_dot_product_attention_multireduce_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
q = Tensor.randn(32,8,16,64).realize()
|
||||
k = Tensor.randn(32,8,16,64).realize()
|
||||
v = Tensor.randn(32,8,16,64).realize()
|
||||
out = Tensor.scaled_dot_product_attention(q,k,v)
|
||||
check_schedule(out, 5) # correctness checked in test_ops
|
||||
|
||||
# multireduce spec
|
||||
def test_ugly_reduceop_pairing(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(4, 32).realize()
|
||||
b = Tensor.randn(4, 32).realize()
|
||||
c = Tensor.randn(4, 32).realize()
|
||||
out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1) # a.sum has >1 children but should still fuse
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 3))
|
||||
np.testing.assert_allclose(out.numpy(), \
|
||||
(c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# multireduce spec
|
||||
def test_reduce_expand_reduce_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(4, 32).realize()
|
||||
out = (a+a.sum(-1, keepdim=True)).sum(-1)
|
||||
# run_schedule(check_schedule(out, 1))
|
||||
run_schedule(check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# multireduce spec
|
||||
def test_reduce_expand_reduce_expand_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(4, 32).realize()
|
||||
out = a+(a+a.sum(-1,keepdim=True)).sum(-1, keepdim=True)
|
||||
# run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule(out, 3))
|
||||
np.testing.assert_allclose(out.numpy(), \
|
||||
a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# multireduce spec
|
||||
def test_branching_reduces_and_expands_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(4, 32).realize()
|
||||
out0 = a+a.sum(-1, keepdim=True)
|
||||
out1 = out0.sum(-1)
|
||||
# run_schedule(check_schedule(out, 2))
|
||||
run_schedule(check_schedule([out0, out1], 3))
|
||||
np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
|
||||
|
||||
# multireduce spec
|
||||
def test_multireduce_fusion_simple_sequential(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user