mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
move to mathtraits.py (#10742)
This commit is contained in:
@@ -1 +1,99 @@
|
||||
from tinygrad.uop.ops import UOp, Ops # noqa: F401
|
||||
from enum import auto, IntEnum, Enum
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
class FastEnum(IntEnum):
|
||||
def __str__(self): return Enum.__str__(self)
|
||||
@staticmethod
|
||||
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
|
||||
|
||||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
|
||||
|
||||
# MetaOps
|
||||
COPY = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
|
||||
# blocks in linearizer
|
||||
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
|
||||
|
||||
# movement ops!
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
||||
|
||||
# misc ops
|
||||
UNROLL = auto(); CONTRACT = auto() # noqa: E702
|
||||
VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
|
||||
DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
||||
VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
||||
|
||||
# helper ops
|
||||
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
|
||||
# load/store before math
|
||||
LOAD = auto(); STORE = auto() # noqa: E702
|
||||
|
||||
# early INDEX
|
||||
INDEX = auto()
|
||||
|
||||
# math ops
|
||||
WMMA = auto()
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702
|
||||
XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
||||
|
||||
# TernaryOps
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
||||
# assignment ops
|
||||
ASSIGN = auto()
|
||||
BIND = auto()
|
||||
|
||||
# control flow ops
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto(); GBARRIER = auto() # noqa: E702
|
||||
|
||||
# consts last!
|
||||
VCONST = auto(); CONST = auto() # noqa: E702
|
||||
|
||||
# device
|
||||
DEVICE = auto()
|
||||
MULTI = auto()
|
||||
|
||||
# CUSTOMI is inline
|
||||
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
||||
IGNORE = auto(); FUSE = auto() # noqa: E702
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
|
||||
Ops.SUB, Ops.FDIV, Ops.POW}
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
||||
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
|
||||
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
|
||||
|
||||
# BinaryOps that can be flipped
|
||||
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
|
||||
|
||||
# BinaryOps where f(f(a,b),c) = f(a,f(b,c))
|
||||
Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
|
||||
|
||||
# BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
|
||||
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
||||
|
||||
Meta = {Ops.COPY, Ops.BUFFER_VIEW}
|
||||
|
||||
All = set(Ops)
|
||||
|
||||
78
tinygrad/uop/mathtraits.py
Normal file
78
tinygrad/uop/mathtraits.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.helpers import T
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
class MathTrait:
|
||||
# required to implement
|
||||
def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
|
||||
def const_like(self:T, b) -> T: raise NotImplementedError
|
||||
|
||||
# great functions you get!
|
||||
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
|
||||
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
|
||||
def logical_not(self): return self.ne(True)
|
||||
def neg(self):
|
||||
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
|
||||
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
|
||||
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
|
||||
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
|
||||
def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
|
||||
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
|
||||
def bitwise_xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
|
||||
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
|
||||
def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
|
||||
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
|
||||
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
|
||||
|
||||
def __neg__(self): return self.neg()
|
||||
|
||||
def __add__(self, x): return self.add(x)
|
||||
def __sub__(self, x): return self.sub(x)
|
||||
def __mul__(self, x): return self.mul(x)
|
||||
def __truediv__(self, x): return self.div(x)
|
||||
def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
|
||||
def __mod__(self, x): return self.mod(x)
|
||||
def __and__(self, x): return self.bitwise_and(x)
|
||||
def __or__(self, x): return self.bitwise_or(x)
|
||||
def __xor__(self, x): return self.bitwise_xor(x)
|
||||
|
||||
def __radd__(self, x): return self.add(x, True)
|
||||
def __rsub__(self, x): return self.sub(x, True)
|
||||
def __rmul__(self, x): return self.mul(x, True)
|
||||
def __rtruediv__(self, x): return self.div(x, True)
|
||||
def __rfloordiv__(self, x): return self.idiv(x, True)
|
||||
def __rand__(self, x): return self.bitwise_and(x, True)
|
||||
def __ror__(self, x): return self.bitwise_or(x, True)
|
||||
def __rxor__(self, x): return self.bitwise_xor(x, True)
|
||||
def __rmod__(self, x): return self.mod(x, True)
|
||||
|
||||
def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
|
||||
def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
|
||||
def __ge__(self, x): return (self < x).logical_not()
|
||||
def __le__(self, x): return (self > x).logical_not()
|
||||
|
||||
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
|
||||
def eq(self, x): return self.ne(x).logical_not()
|
||||
def __ne__(self, x): return self.ne(x)
|
||||
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
|
||||
|
||||
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
|
||||
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
|
||||
def __lshift__(self, x): return self.lshift(x)
|
||||
def __rshift__(self, x): return self.rshift(x)
|
||||
def __rlshift__(self, x): return self.lshift(x, True)
|
||||
def __rrshift__(self, x): return self.rshift(x, True)
|
||||
|
||||
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
|
||||
def minimum(self, x): return -(-self).maximum(-x)
|
||||
def where(self, x, y):
|
||||
if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y))
|
||||
if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y)
|
||||
raise RuntimeError("where needs at least one UOp arg")
|
||||
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
|
||||
def reciprocal(self): return self.alu(Ops.RECIP)
|
||||
def sqrt(self): return self.alu(Ops.SQRT)
|
||||
def sin(self): return self.alu(Ops.SIN)
|
||||
def log2(self): return self.alu(Ops.LOG2)
|
||||
def exp2(self): return self.alu(Ops.EXP2)
|
||||
def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Sequence
|
||||
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
||||
from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, dedup, cdiv, cmod, diskcache_put
|
||||
@@ -10,179 +11,6 @@ if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
|
||||
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
||||
class FastEnum(IntEnum):
|
||||
def __str__(self): return Enum.__str__(self)
|
||||
@staticmethod
|
||||
def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
|
||||
|
||||
class MathTrait:
|
||||
# required to implement
|
||||
def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
|
||||
def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
|
||||
|
||||
# great functions you get!
|
||||
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
|
||||
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
|
||||
def logical_not(self): return self.ne(True)
|
||||
def neg(self):
|
||||
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
|
||||
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
|
||||
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
|
||||
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
|
||||
def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
|
||||
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
|
||||
def bitwise_xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
|
||||
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
|
||||
def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
|
||||
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
|
||||
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
|
||||
|
||||
def __neg__(self): return self.neg()
|
||||
|
||||
def __add__(self, x): return self.add(x)
|
||||
def __sub__(self, x): return self.sub(x)
|
||||
def __mul__(self, x): return self.mul(x)
|
||||
def __truediv__(self, x): return self.div(x)
|
||||
def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
|
||||
def __mod__(self, x): return self.mod(x)
|
||||
def __and__(self, x): return self.bitwise_and(x)
|
||||
def __or__(self, x): return self.bitwise_or(x)
|
||||
def __xor__(self, x): return self.bitwise_xor(x)
|
||||
|
||||
def __radd__(self, x): return self.add(x, True)
|
||||
def __rsub__(self, x): return self.sub(x, True)
|
||||
def __rmul__(self, x): return self.mul(x, True)
|
||||
def __rtruediv__(self, x): return self.div(x, True)
|
||||
def __rfloordiv__(self, x): return self.idiv(x, True)
|
||||
def __rand__(self, x): return self.bitwise_and(x, True)
|
||||
def __ror__(self, x): return self.bitwise_or(x, True)
|
||||
def __rxor__(self, x): return self.bitwise_xor(x, True)
|
||||
def __rmod__(self, x): return self.mod(x, True)
|
||||
|
||||
def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
|
||||
def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
|
||||
def __ge__(self, x): return (self < x).logical_not()
|
||||
def __le__(self, x): return (self > x).logical_not()
|
||||
|
||||
def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
|
||||
def eq(self, x): return self.ne(x).logical_not()
|
||||
def __ne__(self, x): return self.ne(x)
|
||||
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
|
||||
|
||||
def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
|
||||
def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
|
||||
def __lshift__(self, x): return self.lshift(x)
|
||||
def __rshift__(self, x): return self.rshift(x)
|
||||
def __rlshift__(self, x): return self.lshift(x, True)
|
||||
def __rrshift__(self, x): return self.rshift(x, True)
|
||||
|
||||
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
|
||||
def minimum(self, x): return -(-self).maximum(-x)
|
||||
def where(self, x, y):
|
||||
if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y))
|
||||
if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y)
|
||||
raise RuntimeError("where needs at least one UOp arg")
|
||||
def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
|
||||
def reciprocal(self): return self.alu(Ops.RECIP)
|
||||
def sqrt(self): return self.alu(Ops.SQRT)
|
||||
def sin(self): return self.alu(Ops.SIN)
|
||||
def log2(self): return self.alu(Ops.LOG2)
|
||||
def exp2(self): return self.alu(Ops.EXP2)
|
||||
def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
|
||||
|
||||
# the order of these Ops controls the order of the toposort
|
||||
class Ops(FastEnum):
|
||||
# uops that aren't rendered
|
||||
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
|
||||
|
||||
# MetaOps
|
||||
COPY = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
|
||||
# blocks in linearizer
|
||||
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
|
||||
|
||||
# movement ops!
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
||||
|
||||
# misc ops
|
||||
UNROLL = auto(); CONTRACT = auto() # noqa: E702
|
||||
VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
|
||||
DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
|
||||
VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
|
||||
|
||||
# helper ops
|
||||
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
|
||||
|
||||
# UnaryOps
|
||||
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
|
||||
# load/store before math
|
||||
LOAD = auto(); STORE = auto() # noqa: E702
|
||||
|
||||
# early INDEX
|
||||
INDEX = auto()
|
||||
|
||||
# math ops
|
||||
WMMA = auto()
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702
|
||||
XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
||||
|
||||
# TernaryOps
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
||||
# assignment ops
|
||||
ASSIGN = auto()
|
||||
BIND = auto()
|
||||
|
||||
# control flow ops
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto(); GBARRIER = auto() # noqa: E702
|
||||
|
||||
# consts last!
|
||||
VCONST = auto(); CONST = auto() # noqa: E702
|
||||
|
||||
# device
|
||||
DEVICE = auto()
|
||||
MULTI = auto()
|
||||
|
||||
# CUSTOMI is inline
|
||||
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
||||
IGNORE = auto(); FUSE = auto() # noqa: E702
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
|
||||
Ops.SUB, Ops.FDIV, Ops.POW}
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
||||
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
|
||||
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
|
||||
|
||||
# BinaryOps that can be flipped
|
||||
Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
|
||||
|
||||
# BinaryOps where f(f(a,b),c) = f(a,f(b,c))
|
||||
Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
|
||||
|
||||
# BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
|
||||
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
||||
|
||||
Meta = {Ops.COPY, Ops.BUFFER_VIEW}
|
||||
|
||||
All = set(Ops)
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user