from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
import math, operator, ctypes, struct, functools, hashlib
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
from tinygrad.helpers import pretty_print, prod
from tinygrad.shape.symbolic import Variable, sint
if TYPE_CHECKING:
  from tinygrad.shape.shapetracker import ShapeTracker

# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(Enum):
  """A -> A (elementwise)"""
  EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
class BinaryOps(Enum):
  """A + A -> A (elementwise)"""
  ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
class TernaryOps(Enum):
  """A + A + A -> A (elementwise)"""
  WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(Enum):
  """A -> B (reduce)"""
  SUM = auto(); PROD = auto(); MAX = auto() # noqa: E702
class MetaOps(Enum):
  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
|
|
|
|
T = TypeVar("T")
|
|
class MathTrait:
|
|
# required to implement
|
|
def alu(self:T, arg:Union[UnaryOps, BinaryOps, TernaryOps], *src) -> T: raise NotImplementedError
|
|
def const_like(self, b:ConstType|Variable): raise NotImplementedError
|
|
|
|
# great functions you get!
|
|
def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
|
|
def __neg__(self): return self.ne(True) if getattr(self, 'dtype', None) == dtypes.bool else self*(-1)
|
|
def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
|
|
def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
|
|
def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
|
|
def __rsub__(self, x): return self.ufix(x).alu(BinaryOps.ADD, -self)
|
|
def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
|
|
def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
|
|
def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
|
|
def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
|
|
def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
|
|
def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
|
|
def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
|
|
def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
|
|
def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
|
|
def eq(self, x): return -self.ne(x)
|
|
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
|
|
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)
|
|
def ge(self, x): return (-self).lt(-x+1)
|
|
def max(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
|
|
def min(self, x): return -(-self).max(-x)
|
|
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
|
|
def threefry(self, seed): return self.alu(BinaryOps.THREEFRY, seed)
|
|
def recip(self): return self.alu(UnaryOps.RECIP)
|
|
def sqrt(self): return self.alu(UnaryOps.SQRT)
|
|
def sin(self): return self.alu(UnaryOps.SIN)
|
|
def log2(self): return self.alu(UnaryOps.LOG2)
|
|
def exp2(self): return self.alu(UnaryOps.EXP2)
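
# Example (follows directly from the methods above): any subclass implementing alu() and
# const_like() gets every operator above. e.g. for a UOp `a`, `a + 1` expands to
# a.alu(BinaryOps.ADD, a.const_like(1)), producing an ALU ADD UOp with src (a, CONST 1).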
|
|
|
|
# do not preserve f(0) = 0
|
|
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
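# e.g. UnaryOps.RECIP maps 0 to inf and UnaryOps.EXP2 maps 0 to 1, so pushing zero padding
# through these ops would change the result.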
|
|
|
|
REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.PROD:BinaryOps.MUL, ReduceOps.MAX:BinaryOps.MAX}
|
|
|
|
# https://en.wikipedia.org/wiki/Identity_element
|
|
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
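# e.g. identity_element(BinaryOps.ADD, dtypes.float32) == 0.0 and
# identity_element(BinaryOps.MAX, dtypes.int8) == -128 (== dtypes.min(dtypes.int8))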
|
|
|
|
# the order of these UOps controls the order of the toposort
|
|
class UOps(Enum):
|
|
# uops that aren't rendered
|
|
SINK = auto()
|
|
"""
|
|
Holds `UOps.STORE`. SINK defines the AST for a Kernel.
|
|
|
|
- **`dtype`**: `None`
|
|
- **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed.
|
|
- **`arg`**: `Optional[KernelInfo]`
|
|
|
|
NOTE: `ScheduleItem` ASTs do not have the `KernelInfo` arg; `Kernel` inserts it into the SINK later.
|
|
"""
|
|
EXT = auto()
|
|
"""
|
|
Holds a single MetaOp. EXT UOps do not need a Kernel.
|
|
|
|
- **`dtype`**: Output DType
|
|
- **`src`**: `Tuple[]`
|
|
- **`arg`**: (`MetaOps.CUSTOM | MetaOps.COPY | MetaOps.EMPTY | MetaOps.VIEW`, LazyBuffer arg)
|
|
"""
|
|
EXPAND = auto()
|
|
CONTRACT = auto()
|
|
SHAPETRACKER = auto()
|
|
"""
|
|
Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.CONST`.
|
|
|
|
- **`dtype`**: `None`
|
|
- **`src`**: `Tuple[]`
|
|
- **`arg`**: `ShapeTracker`
|
|
"""
|
|
SWIZZLE = auto()
|
|
"""
|
|
Swizzle inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST,
|
|
the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph.
|
|
|
|
Example:
|
|
```python
|
|
a = Tensor.empty(32, 32)
|
|
first_reduce = a.sum()
|
|
output = (a + first_reduce).sum()
|
|
```
|
|
`first_reduce` must broadcast to `(32, 32)` before ADD. We UOp this as:
|
|
|
|
```
|
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
|
UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
|
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
|
x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
|
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
|
x3,
|
|
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
|
|
```
|
|
|
|
The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD:
|
|
|
|
```diff
|
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
|
- UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
|
- UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
|
- UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
|
- x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
|
- UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
|
+ UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=(
|
|
+ UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
|
+ x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
|
+ UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
|
- x3,
|
|
- UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
|
|
+ x2,
|
|
+ UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),))
|
|
|
|
```
|
|
|
|
NOTE: Pushing a SWIZZLE through a reduce changes the axis.
|
|
|
|
NOTE: Pushing a SWIZZLE changes the output shape of that UOp, so every other adjacent node has to be reshaped to match, e.g. the second LOAD is reshaped to `(32, 32, 1, 1)` above.
|
|
|
|
- **`dtype`**: Output DType
|
|
- **`src`**: `Tuple[UOp]`, a single UOp to swizzle.
|
|
- **`arg`**: ShapeTracker
|
|
""" # noqa E501
|
|
DEFINE_GLOBAL = auto()
|
|
DEFINE_VAR = auto()
|
|
DEFINE_LOCAL = auto()
|
|
DEFINE_ACC = auto()
|
|
CONST = auto()
|
|
"""
|
|
Defines a single scalar constant value.
|
|
|
|
- **`dtype`**: The scalar DType of the value.
|
|
|
|
- **`src`**:
|
|
The scheduler creates a CONST with a single SHAPETRACKER UOp src: `Tuple[UOp]`.

The Lowerer replaces the SHAPETRACKER with an empty src.
It uses the ShapeTracker's valid expression as the gate of a `WHERE` UOp mask with sources: `(valid gate, the actual CONST UOp, CONST 0, 0.0 or False)`
|
|
|
|
- **`arg`**: The value.
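
An illustrative sketch (not verbatim lowerer output) of the masked form for an int CONST with value 1:

```
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
  UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPLT),
  UOp(UOps.CONST, dtypes.int, arg=1, src=()),
  UOp(UOps.CONST, dtypes.int, arg=0, src=()),))
```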
|
|
"""
|
|
SPECIAL = auto()
|
|
NOOP = auto()
|
|
GEP = auto()
|
|
# math ops
|
|
CAST = auto()
|
|
"""
|
|
- **`dtype`**: The casted scalar DType
|
|
- **`src`**: `Tuple[UOp]`
|
|
- **`arg`**: `None`
|
|
"""
|
|
BITCAST = auto()
|
|
"""
|
|
- **`dtype`**: The bitcasted scalar DType
|
|
- **`src`**: `Tuple[UOp]`
|
|
- **`arg`**: `None`
|
|
"""
|
|
VECTORIZE = auto()
|
|
"""
|
|
- **`dtype`**: The upcasted vector DType
|
|
- **`src`**: `Tuple[UOp, ...]`
|
|
- **`arg`**: `None`
|
|
|
|
NOTE: Length of sources must match `dtype.count`
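
e.g. a `float4` can be sketched as (illustrative):

```python
UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=tuple(UOp.const(dtypes.float, float(i)) for i in range(4)))
```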
|
|
"""
|
|
ALU = auto()
|
|
"""
|
|
- **`dtype`**: Output DType
|
|
- **`src`**: `Tuple[UOp] | Tuple[UOp, UOp] | Tuple[UOp, UOp, UOp]`
|
|
- **`arg`**: `UnaryOps | BinaryOps | TernaryOps`
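
e.g. `x.alu(BinaryOps.CMPLT, y)` yields an ALU UOp with `dtype=dtypes.bool`, while most other BinaryOps keep the dtype of their sources.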
|
|
"""
|
|
REDUCE = auto()
|
|
REDUCE_AXIS = auto()
|
|
"""
|
|
- **`dtype`**: Output DType
|
|
- **`src`**: Input to reduce `Tuple[UOp]`
|
|
- **`arg`**: `(BinaryOps.ADD | BinaryOps.MUL | BinaryOps.MAX, Tuple[int, ...])`
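
e.g. `arg=(BinaryOps.ADD, (0, 1))` sums over axes 0 and 1, as in the SWIZZLE example above.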
|
|
"""
|
|
WMMA = auto()
|
|
# memory/assignment ops
|
|
LOAD = auto()
|
|
"""
|
|
- **`dtype`**: Output DType
|
|
- **`src`**:
|
|
|
|
The scheduler and Kernel create LOADs with a SHAPETRACKER uop in src.
|
|
|
|
- Normal LOAD: `Tuple[UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_GLOBAL`.
|
|
- SHAPETRACKER UOp.
|
|
|
|
- Local LOAD: `Tuple[UOp, UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_LOCAL`.
|
|
- SHAPETRACKER UOp.
|
|
- Local UOps.STORE to the same local buffer. We will barrier this later.
|
|
|
|
The Lowerer replaces the SHAPETRACKER with an indexing uop and gates the LOAD if needed.
|
|
|
|
- Normal LOAD: `Tuple[UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_GLOBAL`.
|
|
- Indexing UOp, can only return `dtypes.int32`.
|
|
- Gated LOAD: `Tuple[UOp, UOp, UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_GLOBAL`.
|
|
- Indexing UOp, can only return `dtypes.int32`.
|
|
- Gate UOp, can only return `dtypes.bool`.
|
|
- Value if gate is `False`, can only be a `UOps.CONST` with arg 0, 0.0 or `False`.
|
|
- Barriered LOAD: `Tuple[UOp, UOp, UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_LOCAL`.
|
|
- Indexing UOp, can only return `dtypes.int32`.
|
|
- Gate UOp, can only return `dtypes.bool`.
|
|
- Barrier UOp `UOps.BARRIER`.
|
|
- **`arg`**: `None`
|
|
"""
|
|
STORE = auto()
|
|
"""
|
|
- **`dtype`**: `None`
|
|
- **`src`**:
|
|
|
|
Similar to LOAD, the scheduler and Kernel create STOREs with a SHAPETRACKER uop in src:
|
|
|
|
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
|
|
- SHAPETRACKER UOp.
|
|
- Value to store.
|
|
|
|
The Lowerer replaces the SHAPETRACKER with an indexing uop and gates the STORE if needed.
|
|
|
|
- Normal STORE: `Tuple[UOp, UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
|
|
- Indexing Op, can only return `dtypes.int32`.
|
|
- Value to store.
|
|
- Gated STORE: `Tuple[UOp, UOp, UOp, UOp]`
|
|
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
|
|
- Indexing UOp, can only return `dtypes.int32`.
|
|
- Value to store.
|
|
- Gate UOp, can only return `dtypes.bool`. We rewrite this to an IF block in the end.
|
|
- **`arg`**: `None`
|
|
"""
|
|
ASSIGN = auto()
|
|
# control flow ops
|
|
BARRIER = auto()
|
|
"""
|
|
Inserts a synchronization barrier between local stores and local loads (rendered as e.g. `barrier(CLK_LOCAL_MEM_FENCE)` in OpenCL).
|
|
|
|
- **`dtype`**: `None`
|
|
- **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed.
|
|
- **`arg`**: `None`
|
|
"""
|
|
IF = auto()
|
|
"""
|
|
Gates a single STORE to global memory. The IF block could also contain additional UOps the STORE depends on.
|
|
|
|
- **`dtype`**: `None`
|
|
- **`src`**:
|
|
`Tuple[UOp, UOp]`
|
|
- Gate UOp, can only return `dtypes.bool`
|
|
- The second UOp starts the gate block; all of its children are gated until the final STORE.
|
|
- **`arg`**: `None`
|
|
|
|
For example, a local reduce must only run on one thread.
|
|
|
|
The STORE's IF gate:
|
|
```
|
|
UOp(UOps.IF, src=(
|
|
UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPNE),
|
|
UOp(UOps.BARRIER, None, (...))))
|
|
```
|
|
The kernel:
|
|
```
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
if (lidx0!=1) {
|
|
int acc1 = 0;
|
|
for (int ridx1 = 0; ridx1 < 16; ridx1++) {
|
|
int val1 = temp1[ridx1];
|
|
acc1 = (acc1+val1);
|
|
}
|
|
data0[0] = acc1;
|
|
}
|
|
```
|
|
"""
|
|
RANGE = auto()
|
|
# ops that are not graph nodes
|
|
ENDRANGE = auto()
|
|
ENDIF = auto()
|
|
|
|
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
|
|
|
|
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
|
|
|
@dataclass(frozen=True, eq=False)
|
|
class UOp(MathTrait):
|
|
op: UOps
|
|
dtype: Optional[DType] = None
|
|
src: Tuple[UOp, ...] = tuple()
|
|
arg: Any = None
|
|
def __hash__(self): return id(self)
|
|
@functools.cached_property
|
|
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
|
|
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
|
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0].expr) if self.op is not UOps.ALU else \
|
|
self.arg.value, self.dtype, self.src)
|
|
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
|
@functools.cached_property
|
|
def key(self) -> bytes:
|
|
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
|
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
|
def argstr(self):
|
|
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
|
|
def commutative(self) -> bool:
|
|
return (self.op is UOps.ALU and \
|
|
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
|
# *** uop syntactic sugar
|
|
@property
|
|
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
|
|
@property
|
|
def st_arg(self) -> ShapeTracker:
|
|
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
|
|
ret = self.src[self.st_loc]
|
|
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
|
return ret.arg
|
|
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
|
|
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
|
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
|
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
|
|
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
|
|
def sconst_like(self, b:ConstType|Variable): return type(self).const(self.dtype.scalar() if self.dtype is not None else None, b)
|
|
@classmethod
|
|
@functools.lru_cache(None)
|
|
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b)
|
|
@classmethod
|
|
def _const(cls, dtype:Optional[DType], b:ConstType|Variable):
|
|
# TODO: fix dtype of b.max after Variable is just a UOp
|
|
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
|
|
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
|
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
|
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
|
def alu(self, arg, *src:UOp):
|
|
return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
|
|
@classmethod
|
|
def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src)
|
|
@classmethod
|
|
def store(cls, *src:UOp): return cls(UOps.STORE, None, src)
|
|
@functools.cached_property
|
|
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
|
|
@property # parents with self
|
|
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
|
@functools.cached_property
|
|
def full_shape(self) -> Tuple[sint, ...]:
|
|
if self.op is UOps.SHAPETRACKER: return self.arg.shape
|
|
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
|
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
|
|
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
|
def variables(self) -> List[Variable]:
|
|
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
|
return sorted(set.union(*st_vars, set([x.arg[0] for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
|
|
def const_factor(self) -> int:
|
|
"""largest known int that divides self"""
|
|
if self.op is UOps.CONST: return self.arg
|
|
if self.op is UOps.ALU:
|
|
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
|
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
|
return 1
|
|
def divides(self, v) -> Optional[UOp]:
|
|
if v==1: return self
|
|
if self.op is UOps.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
|
if self.op is UOps.ALU:
|
|
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
|
if self.arg is BinaryOps.MUL:
|
|
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
|
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
|
return None # generic None if we aren't sure
|
|
@property
|
|
def vmin(self) -> UOp:
|
|
return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.min(cast(DType, self.dtype)))
|
|
@property
|
|
def vmax(self) -> UOp:
|
|
return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.max(cast(DType, self.dtype)))
|
|
@functools.cached_property
|
|
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
|
# NOTE: returned UOp is assumed to be CONST
|
|
# TODO: fix DEFINE_VAR arg in tests and remove checking len(self.arg)
|
|
if self.op is UOps.DEFINE_VAR and self.arg and len(self.arg) > 1: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None
|
|
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
|
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
|
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
|
if self.op is UOps.CONST: return self, self
|
|
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
|
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
|
if self.arg is BinaryOps.ADD: return self.sconst_like(s0.vmin.arg+s1.vmin.arg), self.sconst_like(s0.vmax.arg+s1.vmax.arg)
|
|
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
|
|
# handle the case where at least one operand is non-negative
|
|
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
|
|
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
|
|
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
|
return self.sconst_like(Lmin*Rmin), self.sconst_like(Lmax*Rmax)
|
|
if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst_like(0), self.sconst_like(s1.vmax.arg-1)
|
|
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
|
if s1.arg > 0: return self.sconst_like(s0.vmin.arg//s1.arg), self.sconst_like(s0.vmax.arg//s1.arg)
|
|
if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg))
|
|
if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg))
|
|
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.const(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
|
|
return None, None
|
|
|
|
@dataclass(frozen=True)
|
|
class KernelInfo:
|
|
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
|
|
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
|
|
dont_use_locals: bool = False # don't use local indexing
|
|
|
|
# ***** ops in python *****
|
|
|
|
def hook_overflow(dv, fxn):
|
|
def wfxn(*args):
|
|
try: return fxn(*args)
|
|
except OverflowError: return dv
|
|
return wfxn
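
# e.g. hook_overflow(math.inf, lambda x: 2**x), used for EXP2 below, returns inf instead of
# raising OverflowError when 2**x is too large for a python float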
|
|
|
|
python_alu: Dict[Op, Callable] = {
|
|
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
|
|
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
|
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
|
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
|
|
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
|
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
|
|
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
|
|
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
|
|
|
def truncate_fp16(x):
|
|
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
|
|
except OverflowError: return math.copysign(math.inf, x)
|
|
|
|
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
|
# TODO: bfloat16
|
|
dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
|
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
|
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
|
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
|
|
if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
|
|
|
|
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
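
# illustrative: exec_alu(BinaryOps.ADD, dtypes.int32, (2**31 - 1, 1)) == -(2**31), since the
# python-level result 2**31 is wrapped back into int32 range by the truncate table above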
|
|
|
|
def uop_alu_resolve(u:UOp) -> sint:
|
|
if u.op is UOps.CONST: return u.arg
|
|
if u.op is UOps.DEFINE_VAR: return u.arg[0]
|
|
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
|
|
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
|
|
|
# ***** uop type spec *****
|
|
|
|
def type_verify(uops):
|
|
for u in uops:
|
|
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
|
if uop is UOps.DEFINE_LOCAL: assert isinstance(dtype, PtrDType), f"invalid dtype for local buffer {dtype}"
|
|
if uop is UOps.DEFINE_GLOBAL: assert isinstance(dtype, (PtrDType, ImageDType)), f"invalid dtype for global buffer {dtype}"
|
|
if isinstance(dtype, ImageDType): assert uop is UOps.DEFINE_GLOBAL, f"{uop} can't be image"
|
|
if uop is UOps.SHAPETRACKER: assert len(src) == 0, f"SHAPETRACKER must only define a ShapeTracker arg {uop}"
|
|
if uop is UOps.REDUCE_AXIS: assert isinstance(arg, tuple) and len(arg) == 2 and arg[0] in BinaryOps, f"invalid arg for REDUCE_AXIS {arg}"
|
|
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
|
|
if uop is UOps.CONST:
|
|
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
|
|
# TODO: intermediate CONST of Variable is DEFINE_VAR
|
|
assert (isinstance(arg, Variable) and u.src) or (type(arg) is type(dtypes.as_const(arg, dtype))), f"type of {arg=} does not match {dtype}"
|
|
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
|
|
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
|
|
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
|
|
if uop is UOps.VECTORIZE:
|
|
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
|
|
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
|
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
|
|
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
|
|
if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool
|
|
if uop is UOps.STORE:
|
|
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
|
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
|
if uop is UOps.ALU:
|
|
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
|
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
|
bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
|
|
assert dtype == bd, f"{arg} output dtype mismatch {dtype=} != {bd=}"
|
|
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
|
elif arg is BinaryOps.IDIV:
|
|
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
|
|
assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
|
|
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
|
# the distance to shift isn't typechecked
|
|
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
|
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
|
elif arg == TernaryOps.WHERE:
|
|
bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
|
|
assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
|
|
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
|
|
|
# ***** uop helpers *****
|
|
|
|
def print_uops(uops:List[UOp]):
|
|
for i,u in enumerate(uops):
|
|
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
|
|
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
|
|
|
|
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
|
flops: sint = 0
|
|
mem: sint = 0
|
|
mults: sint = 1
|
|
mult_stack: List[sint] = []
|
|
dont_count: Set[UOp] = set()
|
|
if ignore_indexing:
|
|
for u in uops:
|
|
if u.op is UOps.LOAD:
|
|
dont_count = dont_count.union(u.src[1].sparents)
|
|
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
|
|
elif u.op is UOps.STORE:
|
|
dont_count = dont_count.union(u.src[1].sparents)
|
|
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
|
|
elif u.op is UOps.IF:
|
|
dont_count = dont_count.union(u.src[0].sparents)
|
|
for u in uops:
|
|
if u.op is UOps.RANGE:
|
|
mult_stack.append(mults)
|
|
mults *= uop_alu_resolve(u.src[1] - u.src[0])
|
|
elif u.op is UOps.ENDRANGE:
|
|
mults = mult_stack.pop(-1)
|
|
elif u.op is UOps.SPECIAL:
|
|
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
|
elif u.op is UOps.LOAD:
|
|
assert u.dtype is not None
|
|
mem += u.dtype.itemsize * mults
|
|
elif u.op is UOps.STORE:
|
|
assert u.src[2].dtype is not None
|
|
mem += u.src[2].dtype.itemsize * mults
|
|
elif u.op is UOps.ALU and u not in dont_count:
|
|
assert u.dtype is not None
|
|
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
|
|
elif u.op is UOps.WMMA and u not in dont_count:
|
|
assert u.arg[1] is not None
|
|
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
|
return flops, mem
|