diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index be2e90645b..27a8790e47 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -1,5 +1,5 @@ from __future__ import annotations -import functools, math +import functools, math, operator from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast from collections import defaultdict from tinygrad.helpers import DEBUG, flatten, all_same @@ -34,10 +34,10 @@ def hook_overflow(dv, fxn): python_alu = { UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))), - UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: lambda x: math.sin(x), UnaryOps.NEG: lambda x: -x, - BinaryOps.MUL: lambda x,y: x*y, BinaryOps.ADD: lambda x,y: x+y, BinaryOps.SUB: lambda x,y: x-y, BinaryOps.XOR: lambda x,y: x^y, - BinaryOps.MAX: lambda x,y: max(x, y), BinaryOps.CMPEQ: lambda x,y: x==y, BinaryOps.CMPLT: lambda x,y: x= 0 else math.nan, UnaryOps.SIN: math.sin, UnaryOps.NEG: operator.neg, + BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor, + BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod, + BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan), TernaryOps.WHERE: lambda x,y,z: y if x else z} def exec_alu(arg, dtype, p): diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 4b875cd17d..6f286b6882 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -30,8 +30,7 @@ class CStyleLanguage(NamedTuple): BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", - TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" - } + TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"} # returns a str expression of the casted xs with the given type def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str: diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index c4eaaf0793..174e95ca92 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -23,8 +23,7 @@ code_for_op: Final[Dict[Op, Callable]] = { BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), # noqa: E501 BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y), - TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z), -} + TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z)} dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),