mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-11 15:15:13 -05:00
move exec_alu from uops to ops (#4033)
will use this for const folding in lazy too
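
For context, exec_alu (moved in the diff below) evaluates one ALU op on plain Python values and truncates the result to the target dtype. A minimal sketch of the kind of const folding the message alludes to, assuming the post-commit import path; try_const_fold is a hypothetical helper, not the actual lazy.py change:

# hypothetical helper: fold an ALU op whose inputs are already known Python constants
from tinygrad.ops import BinaryOps, exec_alu
from tinygrad.dtype import dtypes

def try_const_fold(op, dtype, src_vals):
  # src_vals holds the Python value of each source, or None if it is not a constant
  if any(v is None for v in src_vals): return None  # not foldable
  return exec_alu(op, dtype, src_vals)              # evaluate in Python, truncated to dtype

print(try_const_fold(BinaryOps.ADD, dtypes.int32, [2, 3]))     # 5
print(try_const_fold(BinaryOps.MUL, dtypes.int32, [4, None]))  # None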
tinygrad/codegen/uops.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
-import functools, math, operator, itertools, ctypes
+import functools, itertools
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, prod
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
 from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
 from enum import Enum, auto
 from dataclasses import dataclass
@@ -27,33 +27,6 @@ class UOp:
   @staticmethod
   def const(dtype, val): return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))
 
-def hook_overflow(dv, fxn):
-  def wfxn(*args):
-    try: return fxn(*args)
-    except OverflowError: return dv
-  return wfxn
-
-python_alu = {
-  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
-  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
-  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
-  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
-  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
-  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
-  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
-  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
-  TernaryOps.WHERE: lambda x,y,z: y if x else z}
-
-truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
-  # TODO: float16 and bfloat16?
-  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
-  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
-
-def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
-
 def uop_alu_resolve(u:UOp) -> sint:
   if u.uop is UOps.CONST: return u.arg
   elif u.uop is UOps.DEFINE_VAR: return u.arg
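
Since UOp.const stays in uops.py (context above) and exec_alu is now imported from tinygrad.ops, the uop side can still collapse an ALU uop whose inputs are all constants. A rough sketch of that pattern against the post-commit tree; fold_const_alu is illustrative, not the actual UOpGraph code:

# illustrative only: replace an ALU uop with all-CONST inputs by a single CONST uop
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.ops import exec_alu

def fold_const_alu(uop, dtype, vin, arg):
  # uses the fields shown in this file: u.uop, u.arg, and UOp.const(dtype, val)
  if uop is UOps.ALU and all(x.uop is UOps.CONST for x in vin):
    return UOp.const(dtype, exec_alu(arg, dtype, [x.arg for x in vin]))
  return None  # leave non-constant ALU uops alone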
tinygrad/ops.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import Union, Type, Tuple, Any, List, Dict, Callable
-import functools, hashlib
+import functools, hashlib, math, operator, ctypes
 from enum import Enum, auto
 from dataclasses import dataclass
 from tinygrad.helpers import prod, dedup
@@ -96,3 +96,32 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
+
+# **************** ops in python ****************
+
+def hook_overflow(dv, fxn):
+  def wfxn(*args):
+    try: return fxn(*args)
+    except OverflowError: return dv
+  return wfxn
+
+python_alu = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
+  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
+  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
+  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
+  TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
+  # TODO: float16 and bfloat16?
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+
+def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
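
The truncate table above is what makes exec_alu respect the output dtype: the Python result is passed through the matching ctypes type, so integer ops wrap the way the corresponding C types do. A small usage example against the definitions added above, assuming the post-commit import path:

from tinygrad.ops import BinaryOps, exec_alu
from tinygrad.dtype import dtypes

print(exec_alu(BinaryOps.ADD, dtypes.uint8, [250, 10]))       # 260 wraps to 4 in uint8
print(exec_alu(BinaryOps.MUL, dtypes.int32, [2**20, 2**12]))  # 2**32 wraps to 0 in int32
print(exec_alu(BinaryOps.CMPLT, dtypes.bool, [1.0, 2.0]))     # True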
tinygrad/runtime/ops_python.py
@@ -6,8 +6,8 @@ import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Allocator, Compiler, CompilerOptions
-from tinygrad.codegen.uops import UOpGraph, UOps, exec_alu
-from tinygrad.ops import BinaryOps, TernaryOps
+from tinygrad.codegen.uops import UOpGraph, UOps
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
 
 def _load(m, i):
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
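
With the import now coming from tinygrad.ops, the PYTHON backend keeps interpreting ALU uops in pure Python. The sketch below shows the general shape of that dispatch, applying one op across per-lane values; run_alu_lanes is an illustrative name, not the backend's actual code:

from tinygrad.ops import TernaryOps, exec_alu
from tinygrad.dtype import dtypes

def run_alu_lanes(op, dtype, srcs):
  # srcs: one list of Python values per source operand, all the same length
  return [exec_alu(op, dtype, p) for p in zip(*srcs)]

# WHERE(cond, a, b) evaluated lane by lane
print(run_alu_lanes(TernaryOps.WHERE, dtypes.int32, [[True, False], [10, 10], [20, 20]]))  # [10, 20]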