From cfd28517dfd34850f35cfc69994604df7b74edaf Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 7 Feb 2025 12:51:43 -0500 Subject: [PATCH] move pow folding tests to test_schedule [pr] (#8955) not really belongs to test_const_folding --- test/test_const_folding.py | 30 +----------------------------- test/test_schedule.py | 31 +++++++++++++++++++++++++------ 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 920bf6bcda..aaf1eb8e63 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -1,6 +1,6 @@ import unittest, math from tinygrad import Tensor, Device, dtypes -from tinygrad.ops import Ops, GroupOp +from tinygrad.ops import Ops from tinygrad.helpers import CI import numpy as np from tinygrad.device import is_dtype_supported @@ -97,34 +97,6 @@ class TestBinaryOpsConstFolding(unittest.TestCase): def test_tensor_one_pow(self): _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4])) - def test_2_pow_is_exp2(self): - t = 2.0 ** Tensor([1.0, 2.0, 3.0]) - s = [s for s in t.schedule() if s.ast.op is Ops.SINK] - self.assertEqual(len(s), 1) - alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU] - self.assertEqual(alu, [Ops.EXP2]) - - def test_pow_05_is_sqrt(self): - t = Tensor([1.0, 2.0, 3.0]) ** 0.5 - s = [s for s in t.schedule() if s.ast.op is Ops.SINK] - self.assertEqual(len(s), 1) - alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU] - self.assertEqual(alu, [Ops.SQRT]) - - def test_pow_neg_05_is_rsqrt(self): - t = Tensor([1.0, 2.0, 3.0]) ** -0.5 - s = [s for s in t.schedule() if s.ast.op is Ops.SINK] - self.assertEqual(len(s), 1) - alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU] - self.assertEqual(alu, [Ops.RECIP, Ops.SQRT]) - - def test_pow_8_has_3_muls(self): - t = Tensor([1.0, 2.0, 3.0]) ** 8 - s = [s for s in t.schedule() if s.ast.op is Ops.SINK] - self.assertEqual(len(s), 1) - alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU] - self.assertEqual(alu, [Ops.MUL, Ops.MUL, Ops.MUL]) - # folds advance indexing into basic indexing class TestIndexingConstFolding(unittest.TestCase): def test_scalar_index(self): diff --git a/test/test_schedule.py b/test/test_schedule.py index f8f4cca6c0..315a64dfd7 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View -from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views +from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views, GroupOp from tinygrad.spec import type_verify, shape_spec from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym @@ -559,11 +559,30 @@ class TestSchedule(unittest.TestCase): out = x.to('python') check_schedule(out, 0, filter_sink=False) - def test_pow_const_tensor_simplified(self): - x = Tensor([1,2,3,4]) - # NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5) - out = x ** Tensor(2.0) - check_schedule(out, 1) + def _alu_from_tensor(self, t:Tensor): + s = [s for s in t.schedule() if s.ast.op is Ops.SINK] + self.assertEqual(len(s), 1) + return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU] + + def test_2_pow_is_exp2(self): + t = 2.0 ** Tensor([1.0, 2.0, 3.0]) + self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2]) + + def test_pow_05_is_sqrt(self): + t = Tensor([1.0, 2.0, 3.0]) ** 0.5 + self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT]) + + def test_pow_neg_05_is_rsqrt(self): + t = Tensor([1.0, 2.0, 3.0]) ** -0.5 + self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT]) + + def test_pow_2_has_1_mul(self): + t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0) + self.assertEqual(self._alu_from_tensor(t), [Ops.MUL]) + + def test_pow_8_has_3_muls(self): + t = Tensor([1.0, 2.0, 3.0]) ** 8 + self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL]) def test_pow_const_tensor_to_zero(self): x = Tensor([1,2,3,4])