Move the pure-Python ALU emulation (hook_overflow, python_alu, truncate,
exec_alu) out of tinygrad/codegen/uops.py and into tinygrad/ops.py, and
switch uops.py and runtime/ops_python.py to import exec_alu from tinygrad.ops.

NOTE(review): this patch arrived with its line breaks and indentation
collapsed. The line structure below was rebuilt from the hunk headers and the
2-space indentation is assumed; verify against the tree before applying.
NOTE(review): the new tinygrad/ops.py still ends without a trailing newline.

diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index d1da750417..285ea0eeb0 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
-import functools, math, operator, itertools, ctypes
+import functools, itertools
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, prod
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
 from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
 from enum import Enum, auto
 from dataclasses import dataclass
@@ -27,33 +27,6 @@ class UOp:
   @staticmethod
   def const(dtype, val): return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))
 
-def hook_overflow(dv, fxn):
-  def wfxn(*args):
-    try: return fxn(*args)
-    except OverflowError: return dv
-  return wfxn
-
-python_alu = {
-  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
-  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
-  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
-  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
-  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
-  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
-  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
-  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
-  TernaryOps.WHERE: lambda x,y,z: y if x else z}
-
-truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
-  # TODO: float16 and bfloat16?
-  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
-  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
-
-def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
-
 def uop_alu_resolve(u:UOp) -> sint:
   if u.uop is UOps.CONST: return u.arg
   elif u.uop is UOps.DEFINE_VAR: return u.arg
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 8baab55cfa..d1090ce86a 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from typing import Union, Type, Tuple, Any, List, Dict, Callable
-import functools, hashlib
+import functools, hashlib, math, operator, ctypes
 from enum import Enum, auto
 from dataclasses import dataclass
 from tinygrad.helpers import prod, dedup
@@ -96,3 +96,32 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
+
+# **************** ops in python ****************
+
+def hook_overflow(dv, fxn):
+  def wfxn(*args):
+    try: return fxn(*args)
+    except OverflowError: return dv
+  return wfxn
+
+python_alu = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
+  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
+  BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
+  BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
+  TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
+  # TODO: float16 and bfloat16?
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+
+def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
\ No newline at end of file
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 1febe89117..45b25dcb8b 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -6,8 +6,8 @@ import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Allocator, Compiler, CompilerOptions
-from tinygrad.codegen.uops import UOpGraph, UOps, exec_alu
-from tinygrad.ops import BinaryOps, TernaryOps
+from tinygrad.codegen.uops import UOpGraph, UOps
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
 
 def _load(m, i):
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")