diff --git a/test/test_uops.py b/test/test_uops.py
index 2910b5a664..31bce05bb9 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -7,6 +7,7 @@ from tinygrad.device import Buffer, Device
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.device import CompiledASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, UOp
+from tinygrad.runtime.ops_python import exec_alu
 from test.test_dtype import is_dtype_supported
 
 def _uops_to_prg(uops):
@@ -109,5 +110,9 @@ class TestNonFloatUOps(TestUOps):
   def test_where_float16(self):
     self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float16), PtrDType(dtypes.float16)))
 
+class TestExecALU(TestUOps):
+  def test_sqrt(self):
+    self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.int, (0,)), 0)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 540e3c792d..429879eed4 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -19,7 +19,7 @@ def exec_alu(arg, dtype, p):
   if arg == UnaryOps.EXP2:
     try: return math.exp(p[0]*math.log(2))
     except OverflowError: return math.inf
-  if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] > 0 else math.nan
+  if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
   if arg == UnaryOps.SIN: return math.sin(p[0])
   if arg == UnaryOps.NEG: return -p[0]
   if arg == BinaryOps.MUL: return p[0]*p[1]
@@ -214,4 +214,4 @@ class PythonAllocator(Allocator):
 
 class PythonDevice(Compiled):
   def __init__(self, device:str):
-    super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram)
\ No newline at end of file
+    super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram)