mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
use operator instead of lambda in python_alu (#3590)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import functools, math
|
||||
import functools, math, operator
|
||||
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast
|
||||
from collections import defaultdict
|
||||
from tinygrad.helpers import DEBUG, flatten, all_same
|
||||
@@ -34,10 +34,10 @@ def hook_overflow(dv, fxn):
|
||||
python_alu = {
|
||||
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
|
||||
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: lambda x: math.sin(x), UnaryOps.NEG: lambda x: -x,
|
||||
BinaryOps.MUL: lambda x,y: x*y, BinaryOps.ADD: lambda x,y: x+y, BinaryOps.SUB: lambda x,y: x-y, BinaryOps.XOR: lambda x,y: x^y,
|
||||
BinaryOps.MAX: lambda x,y: max(x, y), BinaryOps.CMPEQ: lambda x,y: x==y, BinaryOps.CMPLT: lambda x,y: x<y,
|
||||
BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan), BinaryOps.MOD: lambda x,y: x%y,
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin, UnaryOps.NEG: operator.neg,
|
||||
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
|
||||
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod,
|
||||
BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan),
|
||||
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
def exec_alu(arg, dtype, p):
|
||||
|
||||
@@ -30,8 +30,7 @@ class CStyleLanguage(NamedTuple):
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
|
||||
}
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
|
||||
|
||||
@@ -23,8 +23,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
BinaryOps.MAX: lambda builder, x, y, var_dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(var_dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(var_dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
|
||||
BinaryOps.MOD: lambda builder, x, y, var_dtype: builder.urem(x, y) if is_bool_or_unsigned(var_dtype) else builder.srem(x, y) if dtypes.is_int(var_dtype) else builder.frem(x, y), # noqa: E501
|
||||
BinaryOps.XOR: lambda builder, x, y, var_dtype: builder.xor(x, y),
|
||||
TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z),
|
||||
}
|
||||
TernaryOps.WHERE: lambda builder, x, y, z, var_dtype: builder.select(x, y, z)}
|
||||
|
||||
dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
|
||||
dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
|
||||
|
||||
Reference in New Issue
Block a user