From c28eceaf44d3661ab0284934998fcce7ed72f602 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 9 Jun 2025 21:17:35 -0700 Subject: [PATCH] move to mathtraits.py (#10742) --- tinygrad/uop/__init__.py | 100 ++++++++++++++++++++- tinygrad/uop/mathtraits.py | 78 ++++++++++++++++ tinygrad/uop/ops.py | 176 +------------------------------------ 3 files changed, 179 insertions(+), 175 deletions(-) create mode 100644 tinygrad/uop/mathtraits.py diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index efae111520..225d0c012b 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -1 +1,99 @@ -from tinygrad.uop.ops import UOp, Ops # noqa: F401 \ No newline at end of file +from enum import auto, IntEnum, Enum + +# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses +class FastEnum(IntEnum): + def __str__(self): return Enum.__str__(self) + @staticmethod + def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]]) + +# the order of these Ops controls the order of the toposort +class Ops(FastEnum): + # uops that aren't rendered + SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702 + + # MetaOps + COPY = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702 + + # blocks in linearizer + BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702 + + # movement ops! + RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702 + + # misc ops + UNROLL = auto(); CONTRACT = auto() # noqa: E702 + VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702 + DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702 + VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702 + + # reduce + REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702 + + # helper ops + GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702 + + # UnaryOps + CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 + + # load/store before math + LOAD = auto(); STORE = auto() # noqa: E702 + + # early INDEX + INDEX = auto() + + # math ops + WMMA = auto() + + # BinaryOps + ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702 + XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702 + + # TernaryOps + WHERE = auto(); MULACC = auto() # noqa: E702 + + # assignment ops + ASSIGN = auto() + BIND = auto() + + # control flow ops + BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto(); GBARRIER = auto() # noqa: E702 + + # consts last! + VCONST = auto(); CONST = auto() # noqa: E702 + + # device + DEVICE = auto() + MULTI = auto() + + # CUSTOMI is inline + CUSTOM = auto(); CUSTOMI = auto() # noqa: E702 + IGNORE = auto(); FUSE = auto() # noqa: E702 + +class GroupOp: + Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} + Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, + Ops.SUB, Ops.FDIV, Ops.POW} + Ternary = {Ops.WHERE, Ops.MULACC} + ALU = set.union(Unary, Binary, Ternary) + + Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} + Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP} + + Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR} + Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART} + + # BinaryOps that can be flipped + Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR} + + # BinaryOps where f(f(a,b),c) = f(a,f(b,c)) + Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX} + + # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence + Idempotent = {Ops.OR, Ops.AND, Ops.MAX} + + # do not preserve f(0) = 0 + UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW} + + Meta = {Ops.COPY, Ops.BUFFER_VIEW} + + All = set(Ops) diff --git a/tinygrad/uop/mathtraits.py b/tinygrad/uop/mathtraits.py new file mode 100644 index 0000000000..81e4d64111 --- /dev/null +++ b/tinygrad/uop/mathtraits.py @@ -0,0 +1,78 @@ +from tinygrad.uop import Ops +from tinygrad.helpers import T +from tinygrad.dtype import dtypes + +class MathTrait: + # required to implement + def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError + def const_like(self:T, b) -> T: raise NotImplementedError + + # great functions you get! + def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x + def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x)) + def logical_not(self): return self.ne(True) + def neg(self): + if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}") + return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1) + def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse) + def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse) + def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse) + def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse) + def bitwise_xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse) + def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse) + def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse) + def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) + def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) + + def __neg__(self): return self.neg() + + def __add__(self, x): return self.add(x) + def __sub__(self, x): return self.sub(x) + def __mul__(self, x): return self.mul(x) + def __truediv__(self, x): return self.div(x) + def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv + def __mod__(self, x): return self.mod(x) + def __and__(self, x): return self.bitwise_and(x) + def __or__(self, x): return self.bitwise_or(x) + def __xor__(self, x): return self.bitwise_xor(x) + + def __radd__(self, x): return self.add(x, True) + def __rsub__(self, x): return self.sub(x, True) + def __rmul__(self, x): return self.mul(x, True) + def __rtruediv__(self, x): return self.div(x, True) + def __rfloordiv__(self, x): return self.idiv(x, True) + def __rand__(self, x): return self.bitwise_and(x, True) + def __ror__(self, x): return self.bitwise_or(x, True) + def __rxor__(self, x): return self.bitwise_xor(x, True) + def __rmod__(self, x): return self.mod(x, True) + + def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x)) + def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self) + def __ge__(self, x): return (self < x).logical_not() + def __le__(self, x): return (self > x).logical_not() + + def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x)) + def eq(self, x): return self.ne(x).logical_not() + def __ne__(self, x): return self.ne(x) + # NOTE: __eq__ isn't overridden, and means the same thing as is by default + + def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse) + def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse) + def __lshift__(self, x): return self.lshift(x) + def __rshift__(self, x): return self.rshift(x) + def __rlshift__(self, x): return self.lshift(x, True) + def __rrshift__(self, x): return self.rshift(x, True) + + def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x)) + def minimum(self, x): return -(-self).maximum(-x) + def where(self, x, y): + if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y)) + if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y) + raise RuntimeError("where needs at least one UOp arg") + def threefry(self, seed): return self.alu(Ops.THREEFRY, seed) + def reciprocal(self): return self.alu(Ops.RECIP) + def sqrt(self): return self.alu(Ops.SQRT) + def sin(self): return self.alu(Ops.SIN) + def log2(self): return self.alu(Ops.LOG2) + def exp2(self): return self.alu(Ops.EXP2) + def pow(self, x): return self.alu(Ops.POW, self.ufix(x)) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 8c9071fe3e..26770784cc 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1,8 +1,9 @@ from __future__ import annotations from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Sequence import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref -from enum import auto, IntEnum, Enum from dataclasses import dataclass, field +from tinygrad.uop import Ops, GroupOp +from tinygrad.uop.mathtraits import MathTrait from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten from tinygrad.helpers import PICKLE_BUFFERS, dedup, cdiv, cmod, diskcache_put @@ -10,179 +11,6 @@ if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer, MultiBuffer -# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses -class FastEnum(IntEnum): - def __str__(self): return Enum.__str__(self) - @staticmethod - def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]]) - -class MathTrait: - # required to implement - def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError - def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError - - # great functions you get! - def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x - def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x)) - def logical_not(self): return self.ne(True) - def neg(self): - if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}") - return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1) - def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse) - def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse) - def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse) - def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse) - def bitwise_xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse) - def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse) - def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse) - def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) - def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) - - def __neg__(self): return self.neg() - - def __add__(self, x): return self.add(x) - def __sub__(self, x): return self.sub(x) - def __mul__(self, x): return self.mul(x) - def __truediv__(self, x): return self.div(x) - def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv - def __mod__(self, x): return self.mod(x) - def __and__(self, x): return self.bitwise_and(x) - def __or__(self, x): return self.bitwise_or(x) - def __xor__(self, x): return self.bitwise_xor(x) - - def __radd__(self, x): return self.add(x, True) - def __rsub__(self, x): return self.sub(x, True) - def __rmul__(self, x): return self.mul(x, True) - def __rtruediv__(self, x): return self.div(x, True) - def __rfloordiv__(self, x): return self.idiv(x, True) - def __rand__(self, x): return self.bitwise_and(x, True) - def __ror__(self, x): return self.bitwise_or(x, True) - def __rxor__(self, x): return self.bitwise_xor(x, True) - def __rmod__(self, x): return self.mod(x, True) - - def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x)) - def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self) - def __ge__(self, x): return (self < x).logical_not() - def __le__(self, x): return (self > x).logical_not() - - def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x)) - def eq(self, x): return self.ne(x).logical_not() - def __ne__(self, x): return self.ne(x) - # NOTE: __eq__ isn't overridden, and means the same thing as is by default - - def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse) - def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse) - def __lshift__(self, x): return self.lshift(x) - def __rshift__(self, x): return self.rshift(x) - def __rlshift__(self, x): return self.lshift(x, True) - def __rrshift__(self, x): return self.rshift(x, True) - - def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x)) - def minimum(self, x): return -(-self).maximum(-x) - def where(self, x, y): - if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y)) - if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y) - raise RuntimeError("where needs at least one UOp arg") - def threefry(self, seed): return self.alu(Ops.THREEFRY, seed) - def reciprocal(self): return self.alu(Ops.RECIP) - def sqrt(self): return self.alu(Ops.SQRT) - def sin(self): return self.alu(Ops.SIN) - def log2(self): return self.alu(Ops.LOG2) - def exp2(self): return self.alu(Ops.EXP2) - def pow(self, x): return self.alu(Ops.POW, self.ufix(x)) - -# the order of these Ops controls the order of the toposort -class Ops(FastEnum): - # uops that aren't rendered - SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702 - - # MetaOps - COPY = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702 - - # blocks in linearizer - BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702 - - # movement ops! - RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702 - - # misc ops - UNROLL = auto(); CONTRACT = auto() # noqa: E702 - VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702 - DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702 - VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702 - - # reduce - REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702 - - # helper ops - GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702 - - # UnaryOps - CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 - - # load/store before math - LOAD = auto(); STORE = auto() # noqa: E702 - - # early INDEX - INDEX = auto() - - # math ops - WMMA = auto() - - # BinaryOps - ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702 - XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702 - - # TernaryOps - WHERE = auto(); MULACC = auto() # noqa: E702 - - # assignment ops - ASSIGN = auto() - BIND = auto() - - # control flow ops - BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto(); GBARRIER = auto() # noqa: E702 - - # consts last! - VCONST = auto(); CONST = auto() # noqa: E702 - - # device - DEVICE = auto() - MULTI = auto() - - # CUSTOMI is inline - CUSTOM = auto(); CUSTOMI = auto() # noqa: E702 - IGNORE = auto(); FUSE = auto() # noqa: E702 - -class GroupOp: - Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} - Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, - Ops.SUB, Ops.FDIV, Ops.POW} - Ternary = {Ops.WHERE, Ops.MULACC} - ALU = set.union(Unary, Binary, Ternary) - - Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} - Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP} - - Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR} - Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART} - - # BinaryOps that can be flipped - Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR} - - # BinaryOps where f(f(a,b),c) = f(a,f(b,c)) - Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX} - - # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence - Idempotent = {Ops.OR, Ops.AND, Ops.MAX} - - # do not preserve f(0) = 0 - UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW} - - Meta = {Ops.COPY, Ops.BUFFER_VIEW} - - All = set(Ops) - # https://en.wikipedia.org/wiki/Identity_element def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)