From 0e02d074bd83f138d5a34b1daf5cb2d835bba04f Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 31 Mar 2024 19:57:23 -0400 Subject: [PATCH] fix Tensor.pow folding for exponent 0 and 1 (#4025) --- test/test_const_folding.py | 5 ----- test/test_schedule.py | 2 +- tinygrad/tensor.py | 3 ++- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index d18f5b2f73..9f3c17faaf 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -34,18 +34,13 @@ class TestSimpleConstFolding(unittest.TestCase): def test_div_tensor_one(self): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4)) - # TODO: fix pow const folding - @unittest.expectedFailure def test_pow_literal_zero(self): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0) - @unittest.expectedFailure def test_pow_tensor_zero(self): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.zeros(4)) - @unittest.expectedFailure def test_pow_literal_one(self): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1) - @unittest.expectedFailure def test_pow_tensor_one(self): _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4)) diff --git a/test/test_schedule.py b/test/test_schedule.py index b86d5c98ac..31f1965d0f 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -363,7 +363,7 @@ class TestSchedule(unittest.TestCase): x = Tensor([1,2,3,4]) out = x ** Tensor(0) # NOTE: this is ConstBuffer 0 + ConstBuffer 1 - check_schedule(out, 1) + check_schedule(out, 0) def test_zero_size(self): x = Tensor.empty(2, 3, 0) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9e1df7f99a..cbd31eb9a6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -902,7 +902,8 @@ class Tensor: if not isinstance(x, Tensor) and not reverse: # simple pow identities if x < 0: return self.reciprocal().pow(-x) - if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), self*0+1) + if x == 0: return 1 + self * 0 + if x in [3,2,1]: return functools.reduce(lambda acc,_: acc * self, range(int(x)-1), self) if x == 0.5: return self.sqrt() if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp() ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()