fix exec_alu(UnaryOps.SQRT, <...>, (0,)) + add test (#3487)

* fix exec_alu(UnaryOps.SQRT, <...>, (0,)) + add test

* sqrt(0) != nan

* fix tabs
This commit is contained in:
Carson Radtke
2024-02-23 11:28:00 -06:00
committed by GitHub
parent 52567da07f
commit 15df9406d6
2 changed files with 7 additions and 2 deletions

View File

@@ -19,7 +19,7 @@ def exec_alu(arg, dtype, p):
if arg == UnaryOps.EXP2:
try: return math.exp(p[0]*math.log(2))
except OverflowError: return math.inf
if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] > 0 else math.nan
if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
if arg == UnaryOps.SIN: return math.sin(p[0])
if arg == UnaryOps.NEG: return -p[0]
if arg == BinaryOps.MUL: return p[0]*p[1]
@@ -214,4 +214,4 @@ class PythonAllocator(Allocator):
class PythonDevice(Compiled):
def __init__(self, device:str):
super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram)
super().__init__(device, PythonAllocator(), PythonCompiler(), PythonProgram)