Files
tinygrad/extra/thunder/tiny/tk/tiles.py
2025-12-07 20:10:30 -08:00

273 lines
9.4 KiB
Python

from enum import Enum, auto
import functools
from typing import Callable
from dataclasses import dataclass
from tinygrad.dtype import AddrSpace, DType
from tinygrad.mixin import MathMixin
from tinygrad.uop.ops import UOp, Ops
from extra.thunder.tiny.tk import WARP_THREADS
def unwrap(x):
if hasattr(x, "_uop"): return x._uop
if isinstance(x, (list, tuple)): return type(x)(unwrap(y) for y in x)
if isinstance(x, dict): return {k: unwrap(v) for k,v in x.items()}
return x
def wrap(x, s):
  """Re-wrap UOp results using the wrapper `s` as a template.

  A `UOp` is handed to `s.ruop`. Lists and tuples are rebuilt (same container type) with every
  item wrapped recursively. Everything else -- including dicts -- is returned as-is.
  """
  if isinstance(x, UOp):
    return s.ruop(x)
  if not isinstance(x, (list, tuple)):
    return x
  container = type(x)
  return container(wrap(item, s) for item in x)
def autowrap(source_cls, blacklist=None):
  """Class decorator factory that forwards attribute access to the object stored in `self._uop`.

  The decorated class gets:
    * a `__getattr__` that looks the missing attribute up on `self._uop`; callables come back as
      proxies that unwrap their arguments and wrap their result, other values are wrapped unless
      the attribute name is one of `UOp.__slots__`, in which case the raw value is returned.
    * a forwarding method for every callable dunder of `source_cls` that is not in `blacklist`
      and is not already present in the `__dict__` of any class in the decorated class's MRO.
      Nothing is forwarded if `source_cls` itself appears in that MRO.

  `blacklist` defaults to a fixed set of lifecycle/introspection dunders.
  """
  if blacklist is None:
    blacklist = {
      "__init__", "__new__", "__str__", "__del__", "__repr__", "__dict__", "__getattribute__",
      "__setattr__", "__delattr__", "__weakref__", "__slots__", "__class__",
      "__reduce__", "__reduce_ex__", "__getstate__", "__setstate__", "__hash__"
    }

  def forward_dunder(func):
    # built through a factory so every proxy closes over its own `func`, not the loop variable
    def proxy(self, *args, **kwargs):
      return wrap(func(self._uop, *unwrap(args), **unwrap(kwargs)), self)
    return proxy

  def decorator(cls):
    def __getattr__(self, name):
      # go through object.__getattribute__ so a missing `_uop` cannot re-enter this hook
      target = object.__getattribute__(self, "_uop")
      attr = getattr(target, name)
      if callable(attr):
        @functools.wraps(attr)
        def proxy(*args, **kwargs):
          return wrap(attr(*unwrap(args), **unwrap(kwargs)), self)
        return proxy
      # values stored in UOp slots are handed back raw, anything else goes through wrap
      return attr if name in UOp.__slots__ else wrap(attr, self)  # type: ignore
    cls.__getattr__ = __getattr__

    for name in dir(source_cls):
      if not name.startswith("__") or name in blacklist: continue
      # skip names the class (or one of its bases) already provides, and skip everything
      # when source_cls is itself part of the MRO
      if any(base is source_cls or name in base.__dict__ for base in cls.mro()): continue
      original = getattr(source_cls, name)
      if callable(original): setattr(cls, name, forward_dunder(original))
    return cls
  return decorator
class TileMathMixin(MathMixin):
def alu(self, op, *src, inner_op=lambda x:x):
assert isinstance(self, (RT, RV))
if len(src) == 0:
if self._uop._shape is None: uop = UOp.alu(self._uop, op)
else: uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op))
elif len(src) == 1:
if self._uop._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
else:
if isinstance(self, RT) and isinstance(src[0], RV):
match self.layout:
case TileLayout.ROW: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0])))
case TileLayout.COL: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[1], 0])))
else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
else: raise NotImplementedError
return self.ruop(uop)
def const_like(self, b): return b
# override ops that do compute on the src uop
def sub(self, x, reverse=False):
return self.ufix(x).alu(Ops.ADD, self, inner_op=lambda y: -y) if reverse else self.alu(Ops.ADD, self.ufix(x), inner_op=lambda y: -y)
def div(self, x, reverse=False):
return self.ufix(x).alu(Ops.MUL, self, inner_op=lambda y: 1/y) if reverse else self.alu(Ops.MUL, self.ufix(x), inner_op=lambda y: 1/y)
@autowrap(UOp)
class GL:
  """Wrapper pairing a UOp with the kernel object `ker`; `create` allocates it in AddrSpace.GLOBAL.

  Attributes and dunders not defined here are forwarded to the wrapped UOp by `autowrap`.
  """
  def __init__(self, uop:UOp, ker):
    self._uop = uop
    self.ker = ker

  def ruop(self, uop:UOp):
    """Return a new GL around `uop` that shares this instance's `ker`."""
    return GL(uop, self.ker)

  @classmethod
  def create(cls, shape, dtype:DType, ker):
    """Allocate `shape`/`dtype` through `ker.alloc` in AddrSpace.GLOBAL and wrap the result."""
    return cls(ker.alloc(shape, dtype, AddrSpace.GLOBAL), ker)
class TileLayout(Enum):
  """Layout tag carried by ST and RT tiles.

  TileMathMixin.alu uses it when an RT is combined with an RV: a ROW tile indexes the vector
  with the first tile index, a COL tile with the second.
  """
  ROW = auto()
  COL = auto()
class VecLayout(Enum):
  """Layout tag carried by RV vectors; ORTHO is the only layout RV.create handles."""
  ORTHO = auto()
@dataclass(frozen=True)
class BaseShape:
  """Immutable (rows, cols) extent of one base tile."""
  rows: int
  cols: int

  @property
  def num_elements(self):
    """Number of elements in the base tile: rows * cols."""
    return self.rows * self.cols

  @property
  def elements_per_thread(self):
    """Base-tile elements divided (floor) by WARP_THREADS."""
    return (self.rows * self.cols) // WARP_THREADS
@dataclass(frozen=True)
class STBaseShape(BaseShape):
  """Base shape for ST tiles: a BaseShape plus a swizzle function and a bytes-per-thread function."""
  # maps (byte offset, dtype) -> swizzled byte offset
  _swizzle: Callable[[UOp, DType], UOp]
  # maps dtype -> bytes per thread
  bytes_per_thread: Callable[[DType], int]

  def swizzle(self, row, col, dtype:DType):
    """Return the swizzled element offset of (row, col) inside one base tile.

    The row-major element index is scaled to bytes by `dtype.itemsize`, passed through
    `_swizzle`, and floor-divided back to an element index.
    """
    byte_offset = (row * self.cols + col) * dtype.itemsize
    return self._swizzle(byte_offset, dtype) // dtype.itemsize
def st_16x16_swizzle(offset:UOp, _):
  """Identity swizzle for the plain 16x16 shape: the offset is returned as given, dtype is ignored."""
  return offset
def st_16x16_bpt(dtype:DType):
  """Bytes per thread for the plain 16x16 shape: 16 for 2- and 4-byte dtypes.

  Raises NotImplementedError for any other item size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16
# 16x16 base tile with the identity swizzle
ST_16X16 = STBaseShape(rows=16, cols=16, _swizzle=st_16x16_swizzle, bytes_per_thread=st_16x16_bpt)
def st_16x16_swizzled_swizzle(offset:UOp, dtype:DType):
if dtype.itemsize == 2:
swizzle = ((offset % 512) >> 7) << 3
return offset ^ swizzle
elif dtype.itemsize == 4:
return offset
else: raise NotImplementedError
def st_16x16_swizzled_bpt(dtype:DType):
if dtype.itemsize == 2: return 4
elif dtype.itemsize == 4: return 16
else: raise NotImplementedError
# 16x16 base tile whose 2-byte dtypes get an XOR swizzle
ST_16X16_SWIZZLED = STBaseShape(rows=16, cols=16, _swizzle=st_16x16_swizzled_swizzle, bytes_per_thread=st_16x16_swizzled_bpt)
def st_32x32_swizzle(offset:UOp, dtype:DType):
if dtype.itemsize == 2:
first_swizzle = ((offset % 1024) >> 9) << 5
second_swizzle = ((offset % 2048) >> 10) << 4
return offset ^ first_swizzle ^ second_swizzle
elif dtype.itemsize == 4:
return offset
else: raise NotImplementedError
def st_32x32_bpt(dtype:DType):
  """Bytes per thread for the 32x32 shape: 16 for 2- and 4-byte dtypes.

  Raises NotImplementedError for any other item size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16
# 32x32 base tile whose 2-byte dtypes get a two-term XOR swizzle
ST_32X32 = STBaseShape(rows=32, cols=32, _swizzle=st_32x32_swizzle, bytes_per_thread=st_32x32_bpt)
def st_16x32_swizzle(offset:UOp, dtype:DType):
if dtype.itemsize == 2:
swizzle = ((offset % 1024) >> 9) << 5
return offset ^ swizzle
elif dtype.itemsize == 4:
return offset
else: raise NotImplementedError
def st_16x32_bpt(dtype:DType):
  """Bytes per thread for the 16x32 shape: 16 for 2- and 4-byte dtypes.

  Raises NotImplementedError for any other item size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16
# 16-row by 32-column base tile
ST_16X32 = STBaseShape(rows=16, cols=32, _swizzle=st_16x32_swizzle, bytes_per_thread=st_16x32_bpt)
def st_32x16_swizzle(offset:UOp, dtype:DType):
if dtype.itemsize == 2:
swizzle = ((offset % 1024) >> 9) << 4
return offset ^ swizzle
elif dtype.itemsize == 4:
return offset
else: raise NotImplementedError
def st_32x16_bpt(dtype:DType):
  """Bytes per thread for the 32x16 shape: 16 for 2- and 4-byte dtypes.

  Raises NotImplementedError for any other item size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16
# 32-row by 16-column base tile
ST_32X16 = STBaseShape(rows=32, cols=16, _swizzle=st_32x16_swizzle, bytes_per_thread=st_32x16_bpt)
@autowrap(UOp)
class ST:
  """Wrapper around a UOp plus tile metadata (rows, cols, layout, base shape) and the kernel `ker`.

  `create` allocates the UOp in AddrSpace.LOCAL. Attributes and dunders not defined here are
  forwarded to the wrapped UOp by `autowrap`.
  """
  def __init__(self, uop:UOp, rows:int, cols:int, layout:TileLayout, base_shape:STBaseShape, ker):
    self._uop, self.ker = uop, ker
    self.rows, self.cols = rows, cols
    self.layout, self.base_shape = layout, base_shape

  def ruop(self, uop:UOp):
    """Return a new ST around `uop` that keeps this instance's metadata and `ker`."""
    return ST(uop, self.rows, self.cols, self.layout, self.base_shape, self.ker)

  @classmethod
  def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:STBaseShape, ker):
    """Allocate a tile whose last two dims of `shape` are split into base tiles.

    The allocated shape is `shape[:-2] + (height, width, base_shape.rows, base_shape.cols)` where
    height/width count base tiles. Asserts that rows and cols are multiples of the base shape and
    that cols is a multiple of `base_shape.elements_per_thread`.
    """
    rows, cols = shape[-2], shape[-1]
    assert rows % base_shape.rows == 0
    assert cols % base_shape.cols == 0
    assert cols % base_shape.elements_per_thread == 0
    tile_grid = (rows // base_shape.rows, cols // base_shape.cols)
    tiled_shape = shape[:-2] + tile_grid + (base_shape.rows, base_shape.cols)
    return cls(ker.alloc(tiled_shape, dtype, AddrSpace.LOCAL), rows, cols, layout, base_shape, ker)

  def swizzle(self, row, col):
    """Map (row, col) inside a base tile to its swizzled (row, col).

    The swizzled element offset comes from `base_shape.swizzle` using the scalar base dtype of
    the wrapped UOp, and is split back into row/col by `base_shape.cols`.
    """
    scalar_dtype = self._uop.dtype.base.scalar()
    flat = self.base_shape.swizzle(row, col, scalar_dtype)
    return flat // self.base_shape.cols, flat % self.base_shape.cols
@dataclass(frozen=True)
class RTBaseShape(BaseShape):
  """Base shape for RT tiles: a BaseShape plus a `stride`."""
  stride: int

  @property
  def num_strides(self):
    """`elements_per_thread` floor-divided by `stride`."""
    per_thread = self.elements_per_thread
    return per_thread // self.stride
# Predefined RTBaseShape instances, named RT_<rows>X<cols>.
# A trailing _4 / _8 marks a variant that differs from the unsuffixed shape only in `stride`.
RT_16X16 = RTBaseShape(rows=16, cols=16, stride=4)
RT_32X32 = RTBaseShape(rows=32, cols=32, stride=4)
RT_32X32_8 = RTBaseShape(rows=32, cols=32, stride=8)
RT_16X32 = RTBaseShape(rows=16, cols=32, stride=8)
RT_32X16 = RTBaseShape(rows=32, cols=16, stride=8)
RT_32X16_4 = RTBaseShape(rows=32, cols=16, stride=4)
RT_16X32_4 = RTBaseShape(rows=16, cols=32, stride=4)
@autowrap(UOp)
class RT(TileMathMixin):
  """Wrapper around a UOp plus a tile layout, an RTBaseShape and the kernel `ker`.

  `create` allocates the UOp in AddrSpace.REG. Math goes through TileMathMixin; attributes and
  dunders not defined on this class or its bases are forwarded to the wrapped UOp by `autowrap`.
  """
  def __init__(self, uop:UOp, layout:TileLayout, base_shape:RTBaseShape, ker):
    self._uop, self.ker = uop, ker
    self.layout, self.base_shape = layout, base_shape

  def ruop(self, uop:UOp):
    """Return a new RT around `uop` that keeps this instance's layout, base shape and `ker`."""
    return RT(uop, self.layout, self.base_shape, self.ker)

  @classmethod
  def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:RTBaseShape, ker):
    """Allocate a 2-D tile as `(height, width, base_shape.elements_per_thread)`.

    height/width count base tiles along each dim of `shape`. Asserts that `shape` has exactly
    two dims and that each is a multiple of the matching base-shape dim.
    """
    assert len(shape) == 2
    n_rows, n_cols = shape[0], shape[1]
    assert n_rows % base_shape.rows == 0
    assert n_cols % base_shape.cols == 0
    tile_grid = (n_rows // base_shape.rows, n_cols // base_shape.cols)
    uop = ker.alloc(tile_grid + (base_shape.elements_per_thread,), dtype, AddrSpace.REG)
    return cls(uop, layout, base_shape, ker)
@autowrap(UOp)
class RV(TileMathMixin):
def __init__(self, uop:UOp, length:int, layout:VecLayout, base_shape:RTBaseShape, ker):
self._uop, self.ker = uop, ker
self.length, self.layout, self.base_shape = length, layout, base_shape
def ruop(self, uop:UOp):
return RV(uop, self.length, self.layout, self.base_shape, self.ker)
@classmethod
def create(cls, length, dtype:DType, layout:VecLayout, base_shape:RTBaseShape, ker):
tiles = length // base_shape.rows
match layout:
case VecLayout.ORTHO:
inner_dim = 1
outer_dim = tiles
uop = ker.alloc((outer_dim, inner_dim), dtype, AddrSpace.REG)
return RV(uop, length, layout, base_shape, ker)
# Union of the raw UOp type and the wrapper classes GL, ST, RT and RV.
ALL_TILES = UOp | GL | ST | RT | RV