diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 3d8f5ec03f..e0a55fe5b0 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -241,5 +241,58 @@ class TestSymbolicJit(unittest.TestCase): expected = a.mean(1).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_var(self): + def f(a): return a.var().realize() + def f0(a): return a.var(0).realize() + def f1(a): return a.var(1).realize() + jf = TinyJit(f) + jf0 = TinyJit(f0) + jf1 = TinyJit(f1) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + # aixs = None + a = Tensor.rand(i, 3) + symbolic = jf(a.reshape(vi, 3)).numpy() + expected = a.var().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + # aixs = 0 + a = Tensor.rand(i, 3) + symbolic = jf0(a.reshape(vi, 3)).numpy() + expected = a.var(0).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + # aixs = 1 + a = Tensor.rand(i, 3) + symbolic = jf1(a.reshape(vi, 3)).reshape(i).numpy() + expected = a.var(1).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + @unittest.skip("failed for some") + def test_var_2d(self): + def f(a): return a.var().realize() + def f0(a): return a.var(0).realize() + def f1(a): return a.var(1).realize() + jf = TinyJit(f) + jf0 = TinyJit(f0) + jf1 = TinyJit(f1) + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + # aixs = None + a = Tensor.rand(i, j) + symbolic = jf(a.reshape(vi, vj)).numpy() + expected = a.var().numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + # aixs = 0 + a = Tensor.rand(i, j) + symbolic = jf0(a.reshape(vi, vj)).reshape(j).numpy() + expected = a.var(0).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + # aixs = 1 + a = Tensor.rand(i, j) + symbolic = jf1(a.reshape(vi, vj)).reshape(i).numpy() + expected = a.var(1).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 6a4b0fbd60..d913aefb31 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -152,42 +152,42 @@ class TestSymbolicOps(unittest.TestCase): def test_mean(self): for i in range(1, 5): vi = Variable("i", 1, 10).bind(i) - # aixs = None - a = Tensor.rand(i, 3) - symbolic = a.reshape(vi, 3).mean().numpy() - expected = a.mean().numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, 3) - symbolic = a.reshape(vi, 3).mean(0).numpy() - expected = a.mean(0).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, 3) - symbolic = a.reshape(vi, 3).mean(1).reshape(i).numpy() - expected = a.mean(1).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + for axis in [None, 0, 1]: + a = Tensor.rand(i, 3) + expected = a.mean(axis).numpy() + symbolic = a.reshape(vi, 3).mean(axis).reshape(expected.shape).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) def test_mean_2d(self): for i in range(1, 5): for j in range(1, 5): vi = Variable("i", 1, 10).bind(i) vj = Variable("j", 1, 10).bind(j) - # aixs = None - a = Tensor.rand(i, j) - symbolic = a.reshape(vi, vj).mean().numpy() - expected = a.mean().numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 0 - a = Tensor.rand(i, j) - symbolic = a.reshape(vi, vj).mean(0).reshape(j).numpy() - expected = a.mean(0).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - # aixs = 1 - a = Tensor.rand(i, j) - symbolic = a.reshape(vi, vj).mean(1).reshape(i).numpy() - expected = a.mean(1).numpy() + for axis in [None, 0, 1]: + a = Tensor.rand(i, j) + expected = a.mean(axis).numpy() + symbolic = a.reshape(vi, vj).mean(axis).reshape(expected.shape).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_var(self): + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + for axis in [None, 0, 1]: + a = Tensor.rand(i, 3) + expected = a.var(axis).numpy() + symbolic = a.reshape(vi, 3).var(axis).reshape(expected.shape).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_var_2d(self): + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + for axis in [None, 0, 1]: + a = Tensor.rand(i, j) + expected = a.var(axis).numpy() + symbolic = a.reshape(vi, vj).var(axis).reshape(expected.shape).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index 45b19928de..aab263428c 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -26,38 +26,31 @@ class TestTensorVariable(unittest.TestCase): assert (Tensor(3) * (vv * 4)).item() == 24 def test_symbolic_mean(self): - vv = Variable("a", 1, 10) - vv.bind(2) + vv = Variable("a", 1, 10).bind(2) t = Tensor.ones(2, 2).contiguous().reshape(2, vv) ret = t.mean().item() assert ret == 1 - @unittest.skip("symbolic var isn't supported") - def test_symbolic_var(self): - vv = Variable("a", 1, 10) - vv.bind(2) - t = Tensor.ones(2, 2).contiguous().reshape(2, vv) - ret = t.var().item() - assert ret == 0 - def test_symbolic_mean_2d(self): - vv = Variable("a", 1, 10) - vv.bind(2) - vv2 = Variable("b", 1, 10) - vv2.bind(2) + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) ret = t.mean().item() assert ret == 1 def test_symbolic_mean_2d_axis_1(self): - vv = Variable("a", 1, 10) - vv.bind(2) - vv2 = Variable("b", 1, 10) - vv2.bind(2) + vv = Variable("a", 1, 10).bind(2) + vv2 = Variable("b", 1, 10).bind(2) t = Tensor.ones(2, 2).contiguous().reshape(vv2, vv) ret = t.mean(axis=1).reshape(2, 1).numpy() assert np.all(ret == 1) + def test_symbolic_var(self): + vv = Variable("a", 1, 10).bind(2) + t = Tensor.ones(2, 2).contiguous().reshape(2, vv) + ret = t.var().item() + assert ret == 0 + @unittest.skip("symbolic arange isn't supported") def test_symbolic_arange(self): vv = Variable("a", 1, 10) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7d16f26a75..8651b6adcb 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -13,7 +13,7 @@ from tinygrad.lazy import LazyBuffer from tinygrad.multi import MultiLazyBuffer from tinygrad.ops import LoadOps from tinygrad.device import Device, Buffer, BufferOptions -from tinygrad.shape.symbolic import sint, Variable, MulNode, Node +from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node from tinygrad.engine.realize import run_schedule from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, memory_planner @@ -316,8 +316,10 @@ class Tensor: @staticmethod def from_node(y:Node, **kwargs) -> Tensor: - if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b + if isinstance(y, NumNode): return Tensor(y.b, **kwargs, requires_grad=False) if isinstance(y, Variable): return Tensor(y, **kwargs, requires_grad=False) + if isinstance(y, MulNode): return Tensor.from_node(y.a, **kwargs) * y.b + if isinstance(y, SumNode): return Tensor.from_node(y.nodes[0], **kwargs) + sum(y.nodes[1:]) raise RuntimeError(f"unhandled Node {y}") # ***** creation llop entrypoint ***** @@ -1352,9 +1354,9 @@ class Tensor: print(t.var(axis=1).numpy()) ``` """ - assert all_int(self.shape), "does not support symbolic shape" - square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) - return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction)) + squares = (self - self.mean(axis=axis, keepdim=True)).square() + n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so]) + return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction)) def std(self, axis=None, keepdim=False, correction=1): """