mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simple failing test for scheduling parallel reduce [pr] (#9501)
* simple failing test for scheduling parallel reduce [pr] * atol
This commit is contained in:
@@ -1958,6 +1958,24 @@ class TestSwizzle(unittest.TestCase):
|
||||
t = a_reduce+b_reduce
|
||||
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
|
||||
|
||||
def test_parallel_reduce_possible(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 2, 2).realize()
|
||||
y = Tensor.randn(4, 2, 2).realize()
|
||||
t = x.sum(axis=1)+y.sum(axis=1)
|
||||
with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(t, 1))
|
||||
np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
|
||||
|
||||
# kernels can only have 1 or n in each dim
|
||||
@unittest.expectedFailure
|
||||
def test_dont_parallelize_different_n(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 2, 2).realize()
|
||||
y = Tensor.randn(4, 3, 2).realize()
|
||||
t = x.sum(axis=1)+y.sum(axis=1)
|
||||
with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(t, 1))
|
||||
np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
|
||||
|
||||
def test_unsafe_pad(self):
|
||||
x = Tensor.full((2,2), 1.0).contiguous()
|
||||
y = x*x.sum((1,)).reciprocal()
|
||||
|
||||
Reference in New Issue
Block a user