Vectorize Transcendental Regression Tests (#9753)

* init test

* cleanup
This commit is contained in:
Eitan Turok
2025-04-07 20:27:39 +03:00
committed by GitHub
parent 7c4a739fe4
commit bb7922b95f

View File

@@ -1,8 +1,9 @@
import unittest, math import unittest, math
import numpy as np import numpy as np
from tinygrad import dtypes from tinygrad import dtypes
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Ops from tinygrad.ops import UOp, Ops
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if from tinygrad.codegen.transcendental import TRANSCENDENTAL_SUPPORTED_DTYPES, payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if, xpow
from test.helpers import eval_uop from test.helpers import eval_uop
class TestTranscendentalFunctions(unittest.TestCase): class TestTranscendentalFunctions(unittest.TestCase):
@@ -70,5 +71,26 @@ class TestTranscendentalFunctions(unittest.TestCase):
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10) np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10)
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63) np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63)
class TestVectorizedTranscendetalFunctions(unittest.TestCase):
def _check_all_uops_vectorized(self, u:tuple|UOp, vcount:int):
# check all UOps in u are vectorized with vcount
if isinstance(u, UOp): assert u.dtype.vcount == vcount, f'expected {vcount=} but got {u.dtype.vcount=} for UOp {u=}'
[self._check_all_uops_vectorized(x, vcount) for x in (u if isinstance(u, tuple) else u.src)]
def _get_inputs(self) -> tuple[UOp, DType]:
for val in [-2,1.3,194]:
for vcount in [1,2,4,19]:
for _dtype in TRANSCENDENTAL_SUPPORTED_DTYPES:
dtype: DType = _dtype.vec(vcount)
d = UOp.const(dtype, val)
yield d, dtype
def test_preserves_vectorization(self):
# verify that when given a vectorized (or scalar) input, the function returns a vectorized (or scalar) output
for d, dtype in self._get_inputs():
self._check_all_uops_vectorized(payne_hanek_reduction(d), dtype.vcount)
self._check_all_uops_vectorized(cody_waite_reduction(d), dtype.vcount)
self._check_all_uops_vectorized(xpow(d, d), dtype.vcount)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()