move exec_alu from uops to ops (#4033)

will use this for const folding in lazy too
chenyu
2024-04-01 20:20:53 -04:00
committed by GitHub
parent 82440d3416
commit 0a34d6016b
3 changed files with 34 additions and 32 deletions
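In short: exec_alu, together with the python_alu and truncate tables it relies on, moves from the uops module into tinygrad.ops, so it can be used without importing the uop machinery. The commit message points at const folding in lazy; the snippet below is only a rough sketch of that idea with made-up names (const_fold, plain python values standing in for constant buffers), not code from this commit:

from tinygrad.ops import BinaryOps, exec_alu
from tinygrad.dtype import dtypes

def const_fold(op, dtype, src_values):
  # evaluate an elementwise op on known constant inputs at graph-build time,
  # truncating the result to the output dtype exactly like exec_alu always does
  return exec_alu(op, dtype, src_values)

# a graph node computing const(3) + const(4) can be replaced by const(7)
assert const_fold(BinaryOps.ADD, dtypes.int32, [3, 4]) == 7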

tinygrad/codegen/uops.py (View File)

@@ -1,10 +1,10 @@
 from __future__ import annotations
-import functools, math, operator, itertools, ctypes
+import functools, itertools
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, prod
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
 from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
 from enum import Enum, auto
 from dataclasses import dataclass
@@ -27,33 +27,6 @@ class UOp:
   @staticmethod
   def const(dtype, val): return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))
-def hook_overflow(dv, fxn):
-  def wfxn(*args):
-    try: return fxn(*args)
-    except OverflowError: return dv
-  return wfxn
-python_alu = {
-  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
-  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
-  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
-  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
-  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
-  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
-  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
-  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
-  TernaryOps.WHERE: lambda x,y,z: y if x else z}
-truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
-  # TODO: float16 and bfloat16?
-  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
-  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
-def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
 def uop_alu_resolve(u:UOp) -> sint:
   if u.uop is UOps.CONST: return u.arg
   elif u.uop is UOps.DEFINE_VAR: return u.arg
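With exec_alu now imported from tinygrad.ops, the uop-level constant folding elsewhere in this file keeps working unchanged. As an illustration only (reusing the UOps, UOp, and exec_alu names from the hunk above, not the file's exact code), folding an ALU uop whose inputs are all constants amounts to:

def try_const_fold(dtype, vin, arg):
  # if every input is a CONST uop, evaluate the ALU op in python and
  # replace the whole expression with a single CONST carrying the result
  if all(v.uop is UOps.CONST for v in vin):
    return UOp.const(dtype, exec_alu(arg, dtype, [v.arg for v in vin]))
  return None  # not foldable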

tinygrad/ops.py (View File)

@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import Union, Type, Tuple, Any, List, Dict, Callable
-import functools, hashlib
+import functools, hashlib, math, operator, ctypes
 from enum import Enum, auto
 from dataclasses import dataclass
 from tinygrad.helpers import prod, dedup
@@ -96,3 +96,32 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   @functools.lru_cache(None)  # NOTE: this cache needs to be recreated for new ASTs
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
+# **************** ops in python ****************
+def hook_overflow(dv, fxn):
+  def wfxn(*args):
+    try: return fxn(*args)
+    except OverflowError: return dv
+  return wfxn
+python_alu = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
+  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
+  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
+  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
+  TernaryOps.WHERE: lambda x,y,z: y if x else z}
+truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
+  # TODO: float16 and bfloat16?
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
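For reference, exec_alu(arg, dtype, p) looks up arg in python_alu, applies it to the operands p, then truncates the result to dtype, so integer results wrap the way the corresponding C types would. A few hand-checkable examples against this revision (the comments give the expected output):

from tinygrad.ops import BinaryOps, UnaryOps, exec_alu
from tinygrad.dtype import dtypes

print(exec_alu(BinaryOps.ADD, dtypes.uint8, [200, 100]))  # 44: 300 wraps modulo 256
print(exec_alu(BinaryOps.ADD, dtypes.int8, [100, 100]))   # -56: signed 8-bit overflow
print(exec_alu(BinaryOps.MOD, dtypes.int32, [-7, 3]))     # -1: C-style remainder, sign follows the dividend
print(exec_alu(UnaryOps.SQRT, dtypes.float32, [-1.0]))    # nan: guarded in python_alu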

tinygrad/runtime/ops_python.py (View File)

@@ -6,8 +6,8 @@ import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Allocator, Compiler, CompilerOptions
-from tinygrad.codegen.uops import UOpGraph, UOps, exec_alu
-from tinygrad.ops import BinaryOps, TernaryOps
+from tinygrad.codegen.uops import UOpGraph, UOps
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
 def _load(m, i):
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
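The pure-Python emulator backend in this file now pulls exec_alu from tinygrad.ops instead of from the uops module; its behaviour is unchanged. Conceptually, evaluating an ALU uop in an interpreter like this one amounts to calling exec_alu once per lane. The sketch below is a simplified illustration with assumed names (eval_alu, per-lane python lists), not the backend's actual executor, which handles many more uop kinds:

from tinygrad.ops import BinaryOps, exec_alu
from tinygrad.dtype import dtypes

def eval_alu(arg, dtype, inputs):
  # inputs: one list of python values per operand, all with the same lane count
  lanes = len(inputs[0])
  return [exec_alu(arg, dtype, [inp[i] for inp in inputs]) for i in range(lanes)]

# a 4-wide ADD: [1,2,3,4] + [10,10,10,10] -> [11, 12, 13, 14]
print(eval_alu(BinaryOps.ADD, dtypes.int32, [[1, 2, 3, 4], [10, 10, 10, 10]]))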