simple python ALU (#3589)

* shorter

* bugfix
This commit is contained in:
George Hotz
2024-03-02 15:50:58 -08:00
committed by GitHub
parent 162dfb07d9
commit 74c9acddb0

View File

@@ -25,26 +25,23 @@ class UOp:
def __repr__(self):
  """Return a one-line, column-aligned summary: uop, dtype, input uops, and arg."""
  # A missing dtype is rendered as an empty (but still padded) column.
  dtype_text = '' if self.dtype is None else str(self.dtype)
  # Only the uop of each input is shown, not the full input UOp.
  input_ops = str([x.uop for x in self.vin])
  return f"{str(self.uop):20s}: {dtype_text:25s} {input_ops:32s} {self.arg}"
def hook_overflow(dv, fxn):
  """Wrap `fxn` so that an OverflowError raised by it yields `dv` instead.

  Args:
    dv: value returned by the wrapper when `fxn` raises OverflowError.
    fxn: callable to wrap; it is called with the wrapper's positional arguments.

  Returns:
    A callable that forwards its positional arguments to `fxn`. Any exception
    other than OverflowError propagates unchanged.
  """
  def guarded(*operands):
    try:
      result = fxn(*operands)
    except OverflowError:
      return dv
    return result
  return guarded
# Pure-Python implementation of each ALU op, keyed by the op enum member.
# Each value is a callable taking the op's operands positionally.
python_alu = {
  # log2 is only defined for x > 0: 0 maps to -inf and anything else (negatives, nan) to nan
  # rather than letting math.log2 raise ValueError.
  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
  # 2**x computed as e**(x*ln 2); math.exp raises OverflowError for large x, which becomes inf.
  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
  # Negative (or nan) input yields nan instead of a ValueError from math.sqrt.
  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
  # math.sin raises ValueError for +/-inf; return nan, consistent with LOG2/SQRT above.
  UnaryOps.SIN: lambda x: math.nan if math.isinf(x) else math.sin(x),
  UnaryOps.NEG: lambda x: -x,
  BinaryOps.MUL: lambda x,y: x*y,
  BinaryOps.ADD: lambda x,y: x+y,
  BinaryOps.SUB: lambda x,y: x-y,
  BinaryOps.XOR: lambda x,y: x^y,
  BinaryOps.MAX: lambda x,y: max(x, y),
  BinaryOps.CMPEQ: lambda x,y: x==y,
  BinaryOps.CMPLT: lambda x,y: x<y,
  # The division flavour is picked from the type of the first operand: ints use Python floor
  # division (a zero divisor raises ZeroDivisionError), everything else uses true division
  # with a zero divisor mapped to nan.
  BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan),
  # Python modulo semantics: the result takes the sign of the divisor.
  BinaryOps.MOD: lambda x,y: x%y,
  TernaryOps.WHERE: lambda x,y,z: y if x else z,
}
def exec_alu(arg, dtype, p):
  """Evaluate ALU op `arg` on the operands in `p` using the `python_alu` table.

  Args:
    arg: op key; must be a key of `python_alu`.
    dtype: accepted for interface compatibility; not used while the integer
      wraparound below stays disabled.
    p: sequence of operands, unpacked positionally into the op's callable.

  Returns:
    Whatever the op's callable returns for the given operands.

  Raises:
    KeyError: if `arg` has no entry in `python_alu`.
  """
  # Single source of truth for op semantics: dispatch through the table instead of
  # repeating each op in a parallel if/elif chain.
  return python_alu[arg](*p)
  # NOTE: integer wraparound to the dtype's bit width is currently disabled; kept for reference.
  #if not dtypes.is_int(dtype): return ret
  #adjusted = 0 if dtypes.is_unsigned(dtype) else 2 ** (dtype.itemsize * 8 - 1)
  #return (ret + adjusted) % 2 ** (dtype.itemsize * 8) - adjusted