use operator instead of lambda in python_alu (#3590)

This commit is contained in:
chenyu
2024-03-02 19:33:21 -05:00
committed by GitHub
parent a89afd4ffa
commit ee41fafdab
3 changed files with 7 additions and 9 deletions

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, math
import functools, math, operator
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, all_same
@@ -34,10 +34,10 @@ def hook_overflow(dv, fxn):
python_alu = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: lambda x: math.sin(x), UnaryOps.NEG: lambda x: -x,
BinaryOps.MUL: lambda x,y: x*y, BinaryOps.ADD: lambda x,y: x+y, BinaryOps.SUB: lambda x,y: x-y, BinaryOps.XOR: lambda x,y: x^y,
BinaryOps.MAX: lambda x,y: max(x, y), BinaryOps.CMPEQ: lambda x,y: x==y, BinaryOps.CMPLT: lambda x,y: x<y,
BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan), BinaryOps.MOD: lambda x,y: x%y,
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin, UnaryOps.NEG: operator.neg,
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod,
BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan),
TernaryOps.WHERE: lambda x,y,z: y if x else z}
def exec_alu(arg, dtype, p):

View File

@@ -30,8 +30,7 @@ class CStyleLanguage(NamedTuple):
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
}
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
# returns a str expression of the casted xs with the given type
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:

View File

@@ -23,8 +23,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), # noqa: E501
BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
}
TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z)}
dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),