mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 17:38:06 -05:00
273 lines
9.4 KiB
Python
273 lines
9.4 KiB
Python
from enum import Enum, auto
|
|
import functools
|
|
from typing import Callable
|
|
from dataclasses import dataclass
|
|
from tinygrad.dtype import AddrSpace, DType
|
|
from tinygrad.mixin import MathMixin
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
|
|
from extra.thunder.tiny.tk import WARP_THREADS
|
|
|
|
def unwrap(x):
  """Recursively replace wrapper objects with the raw object they carry.

  Anything exposing a `_uop` attribute is swapped for that attribute. Lists and
  tuples (including namedtuples) are rebuilt element by element with their own
  type, dicts are rebuilt as plain dicts with unwrapped values (keys are left
  untouched), and every other value is returned unchanged.
  """
  if hasattr(x, "_uop"): return x._uop
  if isinstance(x, (list, tuple)):
    items = [unwrap(y) for y in x]
    # namedtuples take their fields positionally, so they can't be rebuilt from a single iterable
    return type(x)(*items) if hasattr(x, "_fields") else type(x)(items)
  if isinstance(x, dict): return {k: unwrap(v) for k,v in x.items()}
  return x
|
|
|
|
def wrap(x, s):
  """Recursively re-wrap raw UOps through the wrapper `s`.

  A `UOp` becomes `s.ruop(x)`. Lists and tuples (including namedtuples) are
  rebuilt element by element with their own type. Anything else, dicts
  included, is returned unchanged.
  """
  if isinstance(x, UOp): return s.ruop(x)
  if isinstance(x, (list, tuple)):
    items = [wrap(y, s) for y in x]
    # namedtuples take their fields positionally, so they can't be rebuilt from a single iterable
    return type(x)(*items) if hasattr(x, "_fields") else type(x)(items)
  return x
|
|
|
|
def autowrap(source_cls, blacklist=None):
  """Build a class decorator that forwards attribute access to a wrapped `source_cls` object.

  The decorated class must keep the wrapped object in `self._uop` and provide
  `ruop(uop)`, which `wrap` uses to re-wrap results.

  Two forwarding mechanisms are installed on the decorated class:
  - `__getattr__`: an attribute not found normally is read from `self._uop`. Callables come
    back as proxies that unwrap their arguments and wrap their result, attributes named in
    `source_cls.__slots__` are returned raw, and everything else goes through `wrap`.
  - dunder proxies: implicit special-method lookup does not go through `__getattr__`, so every
    callable double-underscore attribute of `source_cls` that is not in `blacklist` and is not
    already defined by the decorated class or one of its bases (`object` included) gets an
    explicit forwarding method.

  `blacklist` is a collection of attribute names that must never be proxied. When it is None a
  default set of construction, representation, attribute-access, pickling and hashing dunders
  is used.
  """
  if blacklist is None:
    blacklist = {
      "__init__", "__new__", "__str__", "__del__", "__repr__", "__dict__", "__getattribute__",
      "__setattr__", "__delattr__", "__weakref__", "__slots__", "__class__",
      "__reduce__", "__reduce_ex__", "__getstate__", "__setstate__", "__hash__"
    }

  # read the slot names from source_cls (this used to be hard-coded to UOp.__slots__)
  slots = getattr(source_cls, "__slots__", ())
  if isinstance(slots, str): slots = (slots,)  # a bare string declares a single slot

  def make_proxy(func):
    # built through a factory so every proxy binds its own `func`, not the loop variable
    @functools.wraps(func)
    def proxy(self, *args, **kwargs):
      return wrap(func(self._uop, *unwrap(args), **unwrap(kwargs)), self)
    return proxy

  def decorator(cls):
    def __getattr__(self, name):
      # object.__getattribute__ raises AttributeError directly if _uop is missing,
      # instead of re-entering this __getattr__
      uop = object.__getattribute__(self, "_uop")
      val = getattr(uop, name)
      if callable(val):
        @functools.wraps(val)
        def proxy(*args, **kwargs):
          return wrap(val(*unwrap(args), **unwrap(kwargs)), self)
        return proxy
      if name in slots: return val
      return wrap(val, self)
    cls.__getattr__ = __getattr__

    for name in dir(source_cls):
      if name in blacklist or not name.startswith("__"): continue

      for base in cls.mro():
        if base is source_cls: break
        if name in base.__dict__: break
      else:
        # no class in the MRO defines `name`, so forward it to the wrapped object
        original = getattr(source_cls, name)
        if callable(original): setattr(cls, name, make_proxy(original))

    return cls
  return decorator
|
|
|
|
class TileMathMixin(MathMixin):
|
|
def alu(self, op, *src, inner_op=lambda x:x):
|
|
assert isinstance(self, (RT, RV))
|
|
if len(src) == 0:
|
|
if self._uop._shape is None: uop = UOp.alu(self._uop, op)
|
|
else: uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op))
|
|
elif len(src) == 1:
|
|
if self._uop._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
|
|
elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
|
|
elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
|
|
else:
|
|
if isinstance(self, RT) and isinstance(src[0], RV):
|
|
match self.layout:
|
|
case TileLayout.ROW: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0])))
|
|
case TileLayout.COL: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[1], 0])))
|
|
else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
|
|
else: raise NotImplementedError
|
|
return self.ruop(uop)
|
|
def const_like(self, b): return b
|
|
|
|
# override ops that do compute on the src uop
|
|
def sub(self, x, reverse=False):
|
|
return self.ufix(x).alu(Ops.ADD, self, inner_op=lambda y: -y) if reverse else self.alu(Ops.ADD, self.ufix(x), inner_op=lambda y: -y)
|
|
def div(self, x, reverse=False):
|
|
return self.ufix(x).alu(Ops.MUL, self, inner_op=lambda y: 1/y) if reverse else self.alu(Ops.MUL, self.ufix(x), inner_op=lambda y: 1/y)
|
|
|
|
@autowrap(UOp)
class GL:
  """Wrapper holding a UOp together with the `ker` object it belongs to.

  `create` allocates the UOp in AddrSpace.GLOBAL.
  """
  def __init__(self, uop:UOp, ker):
    self._uop = uop
    self.ker = ker

  def ruop(self, uop:UOp):
    """Wrap `uop` in a new GL that shares this instance's `ker`."""
    return GL(uop, self.ker)

  @classmethod
  def create(cls, shape, dtype:DType, ker):
    """Allocate `shape`/`dtype` with `ker.alloc` in AddrSpace.GLOBAL and wrap the result."""
    return cls(ker.alloc(shape, dtype, AddrSpace.GLOBAL), ker)
|
|
|
|
class TileLayout(Enum):
  """Layout tag carried by tiles."""
  ROW = auto()
  COL = auto()
|
|
|
|
class VecLayout(Enum):
  """Layout tag carried by vectors; ORTHO is the only member."""
  ORTHO = auto()
|
|
|
|
@dataclass(frozen=True)
class BaseShape:
  """Immutable rows-by-cols description of a base tile."""
  rows: int
  cols: int

  @property
  def num_elements(self):
    """Number of elements in the tile: rows * cols."""
    return self.rows * self.cols

  @property
  def elements_per_thread(self):
    """`num_elements` floor-divided by WARP_THREADS."""
    return self.num_elements // WARP_THREADS
|
|
|
|
@dataclass(frozen=True)
class STBaseShape(BaseShape):
  """BaseShape extended with a swizzle function and a bytes-per-thread function."""
  # maps (byte offset, dtype) to the swizzled byte offset
  _swizzle: Callable[[UOp, DType], UOp]
  # maps a dtype to a byte count per thread
  bytes_per_thread: Callable[[DType], int]

  def swizzle(self, row, col, dtype:DType):
    """Return the swizzled element offset for (`row`, `col`).

    The linear element offset `row * cols + col` is scaled to bytes with
    `dtype.itemsize`, passed through `_swizzle`, and scaled back to elements.
    """
    itemsize = dtype.itemsize
    byte_offset = (row * self.cols + col) * itemsize
    return self._swizzle(byte_offset, dtype) // itemsize
|
|
|
|
def st_16x16_swizzle(offset:UOp, _):
  """Identity swizzle: return `offset` unchanged and ignore the dtype argument."""
  return offset
|
|
def st_16x16_bpt(dtype:DType):
  """Return 16 for dtypes with a 2- or 4-byte itemsize; raise NotImplementedError otherwise."""
  if dtype.itemsize in (2, 4): return 16
  raise NotImplementedError
|
|
# 16x16 base shape with the identity swizzle
ST_16X16 = STBaseShape(rows=16, cols=16, _swizzle=st_16x16_swizzle, bytes_per_thread=st_16x16_bpt)
|
|
|
|
def st_16x16_swizzled_swizzle(offset:UOp, dtype:DType):
|
|
if dtype.itemsize == 2:
|
|
swizzle = ((offset % 512) >> 7) << 3
|
|
return offset ^ swizzle
|
|
elif dtype.itemsize == 4:
|
|
return offset
|
|
else: raise NotImplementedError
|
|
def st_16x16_swizzled_bpt(dtype:DType):
|
|
if dtype.itemsize == 2: return 4
|
|
elif dtype.itemsize == 4: return 16
|
|
else: raise NotImplementedError
|
|
# 16x16 base shape using the swizzled offset function
ST_16X16_SWIZZLED = STBaseShape(rows=16, cols=16, _swizzle=st_16x16_swizzled_swizzle, bytes_per_thread=st_16x16_swizzled_bpt)
|
|
|
|
def st_32x32_swizzle(offset:UOp, dtype:DType):
|
|
if dtype.itemsize == 2:
|
|
first_swizzle = ((offset % 1024) >> 9) << 5
|
|
second_swizzle = ((offset % 2048) >> 10) << 4
|
|
return offset ^ first_swizzle ^ second_swizzle
|
|
elif dtype.itemsize == 4:
|
|
return offset
|
|
else: raise NotImplementedError
|
|
def st_32x32_bpt(dtype:DType):
  """Return 16 for dtypes with a 2- or 4-byte itemsize; raise NotImplementedError otherwise."""
  if dtype.itemsize in (2, 4): return 16
  raise NotImplementedError
|
|
# 32x32 base shape
ST_32X32 = STBaseShape(rows=32, cols=32, _swizzle=st_32x32_swizzle, bytes_per_thread=st_32x32_bpt)
|
|
|
|
def st_16x32_swizzle(offset:UOp, dtype:DType):
|
|
if dtype.itemsize == 2:
|
|
swizzle = ((offset % 1024) >> 9) << 5
|
|
return offset ^ swizzle
|
|
elif dtype.itemsize == 4:
|
|
return offset
|
|
else: raise NotImplementedError
|
|
def st_16x32_bpt(dtype:DType):
  """Return 16 for dtypes with a 2- or 4-byte itemsize; raise NotImplementedError otherwise."""
  if dtype.itemsize in (2, 4): return 16
  raise NotImplementedError
|
|
# 16x32 base shape
ST_16X32 = STBaseShape(rows=16, cols=32, _swizzle=st_16x32_swizzle, bytes_per_thread=st_16x32_bpt)
|
|
|
|
def st_32x16_swizzle(offset:UOp, dtype:DType):
|
|
if dtype.itemsize == 2:
|
|
swizzle = ((offset % 1024) >> 9) << 4
|
|
return offset ^ swizzle
|
|
elif dtype.itemsize == 4:
|
|
return offset
|
|
else: raise NotImplementedError
|
|
def st_32x16_bpt(dtype:DType):
  """Return 16 for dtypes with a 2- or 4-byte itemsize; raise NotImplementedError otherwise."""
  if dtype.itemsize in (2, 4): return 16
  raise NotImplementedError
|
|
# 32x16 base shape
ST_32X16 = STBaseShape(rows=32, cols=16, _swizzle=st_32x16_swizzle, bytes_per_thread=st_32x16_bpt)
|
|
|
|
@autowrap(UOp)
class ST:
  """Wrapper holding a UOp plus its row/column counts, layout, base shape and `ker`.

  `create` allocates the UOp in AddrSpace.LOCAL.
  """
  def __init__(self, uop:UOp, rows:int, cols:int, layout:TileLayout, base_shape:STBaseShape, ker):
    self._uop, self.ker = uop, ker
    self.rows, self.cols = rows, cols
    self.layout, self.base_shape = layout, base_shape

  def ruop(self, uop:UOp):
    """Wrap `uop` in a new ST that reuses this instance's metadata."""
    return ST(uop, self.rows, self.cols, self.layout, self.base_shape, self.ker)

  @classmethod
  def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:STBaseShape, ker):
    """Allocate a tile for `shape` and wrap it.

    The last two entries of `shape` are the row and column counts. Both must be
    multiples of the matching `base_shape` dimension, and the column count must
    also be a multiple of `base_shape.elements_per_thread`. The allocation keeps
    any leading dimensions and replaces the last two with
    (rows // base rows, cols // base cols, base rows, base cols).
    """
    rows, cols = shape[-2], shape[-1]
    assert rows % base_shape.rows == 0
    assert cols % base_shape.cols == 0
    assert cols % base_shape.elements_per_thread == 0

    tile_grid = (rows // base_shape.rows, cols // base_shape.cols)
    alloc_shape = shape[:-2] + tile_grid + (base_shape.rows, base_shape.cols)
    return cls(ker.alloc(alloc_shape, dtype, AddrSpace.LOCAL), rows, cols, layout, base_shape, ker)

  def swizzle(self, row, col):
    """Return the swizzled (row, col) inside a base tile.

    The swizzled linear offset comes from `base_shape.swizzle`, using the scalar
    base dtype of the wrapped uop, and is split back into a row and a column
    using `base_shape.cols`.
    """
    scalar_dtype = self._uop.dtype.base.scalar()
    linear = self.base_shape.swizzle(row, col, scalar_dtype)
    return linear // self.base_shape.cols, linear % self.base_shape.cols
|
|
|
|
@dataclass(frozen=True)
class RTBaseShape(BaseShape):
  """BaseShape extended with a `stride` value."""
  stride: int

  @property
  def num_strides(self):
    """`elements_per_thread` floor-divided by `stride`."""
    per_thread = self.elements_per_thread
    return per_thread // self.stride
|
|
|
|
# predefined base shapes; a numeric suffix after the dimensions (_4 / _8) matches the stride value
RT_16X16 = RTBaseShape(rows=16, cols=16, stride=4)

RT_32X32 = RTBaseShape(rows=32, cols=32, stride=4)
RT_32X32_8 = RTBaseShape(rows=32, cols=32, stride=8)

RT_16X32 = RTBaseShape(rows=16, cols=32, stride=8)
RT_32X16 = RTBaseShape(rows=32, cols=16, stride=8)

RT_32X16_4 = RTBaseShape(rows=32, cols=16, stride=4)
RT_16X32_4 = RTBaseShape(rows=16, cols=32, stride=4)
|
|
|
|
@autowrap(UOp)
class RT(TileMathMixin):
  """Wrapper holding a UOp plus its layout, base shape and `ker`.

  `create` allocates the UOp in AddrSpace.REG.
  """
  def __init__(self, uop:UOp, layout:TileLayout, base_shape:RTBaseShape, ker):
    self._uop, self.ker = uop, ker
    self.layout, self.base_shape = layout, base_shape

  def ruop(self, uop:UOp):
    """Wrap `uop` in a new RT that reuses this instance's layout, base shape and `ker`."""
    return RT(uop, self.layout, self.base_shape, self.ker)

  @classmethod
  def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:RTBaseShape, ker):
    """Allocate a tile for a 2-D `shape` and wrap it.

    Both dimensions must be multiples of the matching `base_shape` dimension.
    The allocation has shape
    (rows // base rows, cols // base cols, base_shape.elements_per_thread).
    """
    assert len(shape) == 2
    n_rows, n_cols = shape[0], shape[1]
    assert n_rows % base_shape.rows == 0
    assert n_cols % base_shape.cols == 0

    alloc_shape = (n_rows // base_shape.rows, n_cols // base_shape.cols, base_shape.elements_per_thread)
    return cls(ker.alloc(alloc_shape, dtype, AddrSpace.REG), layout, base_shape, ker)
|
|
|
|
@autowrap(UOp)
|
|
class RV(TileMathMixin):
|
|
def __init__(self, uop:UOp, length:int, layout:VecLayout, base_shape:RTBaseShape, ker):
|
|
self._uop, self.ker = uop, ker
|
|
self.length, self.layout, self.base_shape = length, layout, base_shape
|
|
|
|
def ruop(self, uop:UOp):
|
|
return RV(uop, self.length, self.layout, self.base_shape, self.ker)
|
|
|
|
@classmethod
|
|
def create(cls, length, dtype:DType, layout:VecLayout, base_shape:RTBaseShape, ker):
|
|
tiles = length // base_shape.rows
|
|
|
|
match layout:
|
|
case VecLayout.ORTHO:
|
|
inner_dim = 1
|
|
outer_dim = tiles
|
|
|
|
uop = ker.alloc((outer_dim, inner_dim), dtype, AddrSpace.REG)
|
|
return RV(uop, length, layout, base_shape, ker)
|
|
|
|
# union of the raw UOp type and every wrapper class defined in this file
ALL_TILES = UOp | GL | ST | RT | RV
|