mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Vectorize Transcendental Regression Tests (#9753)
* init test * cleanup
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if
|
||||
from tinygrad.codegen.transcendental import TRANSCENDENTAL_SUPPORTED_DTYPES, payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if, xpow
|
||||
from test.helpers import eval_uop
|
||||
|
||||
class TestTranscendentalFunctions(unittest.TestCase):
|
||||
@@ -70,5 +71,26 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63)
|
||||
|
||||
class TestVectorizedTranscendetalFunctions(unittest.TestCase):
|
||||
def _check_all_uops_vectorized(self, u:tuple|UOp, vcount:int):
|
||||
# check all UOps in u are vectorized with vcount
|
||||
if isinstance(u, UOp): assert u.dtype.vcount == vcount, f'expected {vcount=} but got {u.dtype.vcount=} for UOp {u=}'
|
||||
[self._check_all_uops_vectorized(x, vcount) for x in (u if isinstance(u, tuple) else u.src)]
|
||||
|
||||
def _get_inputs(self) -> tuple[UOp, DType]:
|
||||
for val in [-2,1.3,194]:
|
||||
for vcount in [1,2,4,19]:
|
||||
for _dtype in TRANSCENDENTAL_SUPPORTED_DTYPES:
|
||||
dtype: DType = _dtype.vec(vcount)
|
||||
d = UOp.const(dtype, val)
|
||||
yield d, dtype
|
||||
|
||||
def test_preserves_vectorization(self):
|
||||
# verify that when given a vectorized (or scalar) input, the function returns a vectorized (or scalar) output
|
||||
for d, dtype in self._get_inputs():
|
||||
self._check_all_uops_vectorized(payne_hanek_reduction(d), dtype.vcount)
|
||||
self._check_all_uops_vectorized(cody_waite_reduction(d), dtype.vcount)
|
||||
self._check_all_uops_vectorized(xpow(d, d), dtype.vcount)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user