mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Vectorize Transcendental Regression Tests (#9753)
* init test * cleanup
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
import unittest, math
|
import unittest, math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad import dtypes
|
from tinygrad import dtypes
|
||||||
|
from tinygrad.dtype import DType
|
||||||
from tinygrad.ops import UOp, Ops
|
from tinygrad.ops import UOp, Ops
|
||||||
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if
|
from tinygrad.codegen.transcendental import TRANSCENDENTAL_SUPPORTED_DTYPES, payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if, xpow
|
||||||
from test.helpers import eval_uop
|
from test.helpers import eval_uop
|
||||||
|
|
||||||
class TestTranscendentalFunctions(unittest.TestCase):
|
class TestTranscendentalFunctions(unittest.TestCase):
|
||||||
@@ -70,5 +71,26 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
|||||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10)
|
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10)
|
||||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63)
|
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63)
|
||||||
|
|
||||||
|
class TestVectorizedTranscendetalFunctions(unittest.TestCase):
|
||||||
|
def _check_all_uops_vectorized(self, u:tuple|UOp, vcount:int):
|
||||||
|
# check all UOps in u are vectorized with vcount
|
||||||
|
if isinstance(u, UOp): assert u.dtype.vcount == vcount, f'expected {vcount=} but got {u.dtype.vcount=} for UOp {u=}'
|
||||||
|
[self._check_all_uops_vectorized(x, vcount) for x in (u if isinstance(u, tuple) else u.src)]
|
||||||
|
|
||||||
|
def _get_inputs(self) -> tuple[UOp, DType]:
|
||||||
|
for val in [-2,1.3,194]:
|
||||||
|
for vcount in [1,2,4,19]:
|
||||||
|
for _dtype in TRANSCENDENTAL_SUPPORTED_DTYPES:
|
||||||
|
dtype: DType = _dtype.vec(vcount)
|
||||||
|
d = UOp.const(dtype, val)
|
||||||
|
yield d, dtype
|
||||||
|
|
||||||
|
def test_preserves_vectorization(self):
|
||||||
|
# verify that when given a vectorized (or scalar) input, the function returns a vectorized (or scalar) output
|
||||||
|
for d, dtype in self._get_inputs():
|
||||||
|
self._check_all_uops_vectorized(payne_hanek_reduction(d), dtype.vcount)
|
||||||
|
self._check_all_uops_vectorized(cody_waite_reduction(d), dtype.vcount)
|
||||||
|
self._check_all_uops_vectorized(xpow(d, d), dtype.vcount)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user