Files
tinygrad/tinygrad/ops.py
2024-07-31 12:59:38 -07:00

171 lines
8.9 KiB
Python

from __future__ import annotations
from typing import Union, Tuple, Any, List, Dict, Callable
import functools, hashlib, math, operator, ctypes, struct
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import prod, dedup, pretty_print
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(Enum):
"""A -> A (elementwise)"""
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
class BinaryOps(Enum):
"""A + A -> A (elementwise)"""
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
class TernaryOps(Enum):
"""A + A + A -> A (elementwise)"""
WHERE = auto(); MULACC = auto() # noqa: E702
class ReduceOps(Enum):
"""A -> B (reduce)"""
SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class MetaOps(Enum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
# do not preserve f(0) = 0
UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
@dataclass(frozen=True)
class MemBuffer:
idx: int
dtype: DType
st: ShapeTracker
@dataclass(frozen=True)
class ConstBuffer:
val: ConstType | Variable
dtype: DType
st: ShapeTracker
@dataclass(frozen=True)
class KernelInfo:
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to EXPAND)
dont_use_locals: bool = False # don't use local indexing
@dataclass(frozen=True, eq=False)
class LazyOp:
op: Op
src: Tuple[LazyOp, ...] = ()
arg: Any = None
def cached_compare(self, x, context):
if id(self) == id(x): return True
if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
if (key := (id(self), id(x))) in context: return context[key]
ret = context[key] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
return ret
def __eq__(self, x): return self.cached_compare(x, context={})
def __repr__(self:LazyOp): return pretty_print(self, lambda x: f'LazyOp({x.op}, arg={x.arg}, src=(%s))')
@functools.cached_property
def dtype(self) -> DType:
if self.op in BufferOps: return self.arg.dtype
if self.op is ReduceOps.WMMA: return self.arg[3] # WMMA can change the type
if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if len(self.src) == 0 and self.op in BufferOps: return self.arg.st.shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src]))
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
@functools.cached_property
def hash(self): return hash((self.op, self.src, self.arg))
def __hash__(self): return self.hash
@functools.cached_property
def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
def vars(self) -> List[Variable]:
extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
# TODO: support non-lazyop
def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
def ne(self, x:LazyOp): return LazyOp(BinaryOps.CMPNE, (self, x))
def eq(self, x:LazyOp): return -self.ne(x)
def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
@staticmethod
def const(val, dtype:DType, shape:Tuple[sint, ...]):
return LazyOp(BufferOps.CONST, (), ConstBuffer(val, dtype, ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)))
# **************** ops in python ****************
def hook_overflow(dv, fxn):
def wfxn(*args):
try: return fxn(*args)
except OverflowError: return dv
return wfxn
python_alu: Dict[Op, Callable] = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
def truncate_fp16(x):
try:
x = float(x)
struct.pack("@e", x)
return x
except OverflowError: return math.copysign(math.inf, x)
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
# TODO: bfloat16
dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(st.shape))
# the living definition of LazyOps
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
assert ast.op is MetaOps.KERNEL, "must be SINK"
sts: Dict[LazyOp, ShapeTracker] = {}
def assert_valid(op:LazyOp, st:ShapeTracker):
if op in sts: return
# restore globals from the two stage reduce
if op.op is BufferOps.LOAD and op.arg.idx < 0:
assert_valid(local_reduce:=op.src[0].src[0], op.arg.st)
return sts.setdefault(op, sts[local_reduce])
for x in op.src: assert_valid(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op in ReduceOps:
axis = op.arg[-1] if op.op is ReduceOps.WMMA else op.arg
assert isinstance(axis, tuple) and all(isinstance(i, int) for i in axis), f"reduceop must have axis {op.arg}"
st = ShapeTracker.from_shape(reduce_st(sts[op.src[0]], axis))
else:
# movementops are pushed to the edges with LOAD
# elementwise inherits shape
st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
for x in op.src:
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
sts[op] = st
for i, out in enumerate(ast.src):
assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
assert_valid(out, out.arg.st)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
return sts