mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 00:55:11 -05:00
add constant folding for WHERE in uops (#3584)
* add constant folding for WHERE in uops * prereqs for generic constant folding * fix test * disable slow overflow logic * make that test faster
This commit is contained in:
@@ -2,36 +2,14 @@
|
||||
# works to test the tensor cores, and all the uops in general
|
||||
# this is the (living) definition of uops
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, math, struct
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Allocator, Compiler
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.uops import UOp, UOps, exec_alu
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
|
||||
def exec_alu(arg, dtype, p):
|
||||
# TODO: make this complete and correctly honor the dtypes
|
||||
# TODO: use this for constant folding
|
||||
if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
|
||||
if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
|
||||
if arg == UnaryOps.EXP2:
|
||||
try: return math.exp(p[0]*math.log(2))
|
||||
except OverflowError: return math.inf
|
||||
if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
|
||||
if arg == UnaryOps.SIN: return math.sin(p[0])
|
||||
if arg == UnaryOps.NEG: return -p[0]
|
||||
if arg == BinaryOps.MUL: return p[0]*p[1]
|
||||
if arg == BinaryOps.ADD: return p[0]+p[1]
|
||||
if arg == BinaryOps.SUB: return p[0]-p[1]
|
||||
if arg == BinaryOps.XOR: return p[0]^p[1]
|
||||
if arg == BinaryOps.MAX: return max(p[0], p[1])
|
||||
if arg == BinaryOps.CMPEQ: return p[0] == p[1]
|
||||
if arg == BinaryOps.CMPLT: return p[0] < p[1]
|
||||
if arg == BinaryOps.DIV: return p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
|
||||
if arg == BinaryOps.MOD: return p[0]%p[1]
|
||||
raise NotImplementedError(f"no support for {arg}")
|
||||
|
||||
def _load(m, i):
|
||||
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
return m[i]
|
||||
|
||||
Reference in New Issue
Block a user