mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
feat: mixins on tiles (#13246)
This commit is contained in:
@@ -7,7 +7,7 @@ from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.helpers import getenv, prod
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.tiles import RT
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, ST, RT, RV
|
||||
|
||||
class Group:
|
||||
def __init__(self, warps:int, ker):
|
||||
@@ -27,7 +27,8 @@ class Group:
|
||||
# ops that only work on a single warp
|
||||
|
||||
clear_rid = 1000
|
||||
def clear(self, reg:UOp, value:float=0):
|
||||
def clear(self, reg:ALL_TILES, value:float=0):
|
||||
reg = cast(UOp, reg)
|
||||
assert self.warps == 1
|
||||
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
|
||||
@@ -38,11 +39,12 @@ class Group:
|
||||
self.ker.push_store(reg_store, reg)
|
||||
return reg.after(reg_store).reshape(reg.shape)
|
||||
|
||||
def zero(self, reg:UOp): return self.clear(reg, 0)
|
||||
def neg_inf(self, reg:UOp): return self.clear(reg, -math.inf)
|
||||
def zero(self, reg:ALL_TILES): return self.clear(reg, 0)
|
||||
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
|
||||
|
||||
copy_rid = 300
|
||||
def copy(self, dst:UOp, src:UOp):
|
||||
def copy(self, dst:ALL_TILES, src:ALL_TILES):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
assert dst.shape == src.shape
|
||||
|
||||
@@ -54,7 +56,8 @@ class Group:
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
|
||||
def mma_AB(self, c:UOp, a:UOp, b:UOp, after=True):
|
||||
def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
|
||||
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
|
||||
assert self.warps == 1
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
@@ -76,7 +79,8 @@ class Group:
|
||||
self.ker.push_store(c_store, c)
|
||||
return c.after(c_store).reshape(c.shape) if after else c_store
|
||||
|
||||
def mma_ABt(self, c:UOp, a:UOp, b:UOp, after=True):
|
||||
def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
|
||||
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
|
||||
assert self.warps == 1
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
@@ -99,7 +103,8 @@ class Group:
|
||||
return c.after(c_store).reshape(c.shape) if after else c_store
|
||||
|
||||
map_rid = 400
|
||||
def map(self, a:UOp, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
|
||||
def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
|
||||
a = cast(UOp, a)
|
||||
assert self.warps == 1
|
||||
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape))
|
||||
@@ -115,7 +120,8 @@ class Group:
|
||||
self.ker.push_store(a_store, a)
|
||||
return a.after(a_store).reshape(a.shape)
|
||||
|
||||
def row_reduce(self, vec:UOp, src:UOp, op:Callable[[UOp, UOp], UOp]):
|
||||
def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp]):
|
||||
vec, src = cast(UOp, vec), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
|
||||
red_local = self.ker.alloc((self.group_threads, 2), src.dtype.base, AddrSpace.LOCAL)
|
||||
@@ -157,7 +163,8 @@ class Group:
|
||||
# ops that can work across multiple warps
|
||||
|
||||
LOAD_INNER = 8
|
||||
def load(self, dst:UOp, src:UOp, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
|
||||
def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
|
||||
if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
|
||||
@@ -216,7 +223,8 @@ class Group:
|
||||
return dst.after(dst_store.barrier()).reshape(dst.shape)
|
||||
|
||||
STORE_INNER = 8
|
||||
def store(self, dst:UOp, src:UOp, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
|
||||
def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
|
||||
|
||||
@@ -68,10 +68,10 @@ class Kernel(AbstractContextManager):
|
||||
|
||||
return uop
|
||||
|
||||
def gl(self, shape, dtype): return GL(shape, dtype, self)._uop
|
||||
def st(self, shape, dtype): return ST(shape, dtype, self)._uop
|
||||
def rt(self, shape, dtype): return RT(shape, dtype, self)._uop
|
||||
def rv(self, length, dtype, layout="naive"): return RV(length, dtype, layout, self)._uop
|
||||
def gl(self, shape, dtype): return GL.create(shape, dtype, self)
|
||||
def st(self, shape, dtype): return ST.create(shape, dtype, self)
|
||||
def rt(self, shape, dtype): return RT.create(shape, dtype, self)
|
||||
def rv(self, length, dtype, layout="naive"): return RV.create(length, dtype, layout, self)
|
||||
|
||||
def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop))
|
||||
|
||||
@@ -80,7 +80,7 @@ class Kernel(AbstractContextManager):
|
||||
rngs = []
|
||||
while self.range_stack: rngs.append(self.range_stack.pop(0)._rng)
|
||||
|
||||
return self.store_stack.pop()[0].end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
return self.store_stack.pop()[0]._uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
|
||||
def endrange(self):
|
||||
last_store = self.store_stack.pop()
|
||||
|
||||
@@ -1,23 +1,114 @@
|
||||
import functools
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.mixin import MathMixin
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
|
||||
def unwrap(x):
  """Recursively replace wrapper objects with the value stored in their ``_uop`` attribute.

  Args:
    x: any value. Objects exposing a ``_uop`` attribute are replaced by that attribute; lists, tuples
      (including namedtuples) and dicts are rebuilt with their elements/values unwrapped; anything else is
      returned unchanged.

  Returns:
    A structure of the same container types as ``x`` with every wrapper replaced by its ``_uop``.
  """
  if hasattr(x, "_uop"): return x._uop
  if isinstance(x, (list, tuple)):
    items = (unwrap(y) for y in x)
    # namedtuple constructors take their fields positionally, so passing one iterable would raise TypeError
    if hasattr(x, "_fields"): return type(x)(*items)
    return type(x)(items)
  # only dict values are unwrapped; keys are kept as-is
  if isinstance(x, dict): return {k: unwrap(v) for k,v in x.items()}
  return x
|
||||
|
||||
def wrap(x, ker, cls):
  """Wrap every ``UOp`` found in ``x`` as ``cls(uop, ker)``.

  Lists and tuples are rebuilt with the same container type and their elements wrapped recursively.
  Any other value (dicts included) is returned untouched.
  """
  if isinstance(x, UOp):
    return cls(x, ker)
  if not isinstance(x, (list, tuple)):
    return x
  container = type(x)
  return container(wrap(item, ker, cls) for item in x)
|
||||
|
||||
def autowrap(source_cls, blacklist=None):
  """Build a class decorator that forwards attribute access and dunder methods to ``self._uop``.

  The decorated class is expected to keep the wrapped object in ``_uop`` and a kernel handle in ``ker``.
  The decorator installs two things on it:

  * a ``__getattr__`` fallback that resolves the name on ``_uop``. Callables come back as proxies that unwrap
    their arguments and re-wrap the result as the decorated class; names listed in ``UOp.__slots__`` are
    returned raw; every other value is passed through ``wrap``.
  * a proxy for each callable attribute of ``source_cls`` whose name starts with ``__``, is not in
    ``blacklist``, and is not reached in the decorated class's MRO (neither defined by a base nor shadowed by
    ``source_cls`` itself being a base).

  Args:
    source_cls: class whose double-underscore callables are mirrored onto the decorated class.
    blacklist: names that must never be proxied; defaults to a fixed set of object-protocol dunders.

  Returns:
    The decorator; it mutates the class in place and returns it.
  """
  if blacklist is None:
    blacklist = {
      "__init__", "__new__", "__str__", "__del__", "__repr__", "__dict__", "__getattribute__",
      "__setattr__", "__delattr__", "__weakref__", "__slots__", "__class__",
      "__reduce__", "__reduce_ex__", "__getstate__", "__setstate__", "__hash__"
    }

  def decorator(cls):
    def forward(func):
      # factory so that each proxy closes over its own `func` instead of the loop variable
      def proxy(self, *args, **kwargs):
        return wrap(func(self._uop, *unwrap(args), **unwrap(kwargs)), self.ker, cls)
      return proxy

    def __getattr__(self, name):
      # read _uop through object.__getattribute__ so a missing _uop raises instead of re-entering here
      target = getattr(object.__getattribute__(self, "_uop"), name)
      if not callable(target):
        return target if name in UOp.__slots__ else wrap(target, self.ker, cls)
      @functools.wraps(target)
      def proxy(*args, **kwargs):
        return wrap(target(*unwrap(args), **unwrap(kwargs)), self.ker, cls)
      return proxy
    cls.__getattr__ = __getattr__

    for name in dir(source_cls):
      if not name.startswith("__") or name in blacklist: continue
      # skip names the class already resolves: either a base defines it or source_cls is itself a base
      if any(base is source_cls or name in base.__dict__ for base in cls.mro()): continue
      candidate = getattr(source_cls, name)
      if callable(candidate): setattr(cls, name, forward(candidate))

    return cls
  return decorator
|
||||
|
||||
class TileMathMixin(MathMixin):
|
||||
def alu(self, op, *src, inner_op=lambda x:x):
|
||||
assert isinstance(self, (RT, RV))
|
||||
if len(src) == 0:
|
||||
if self._uop._shape is None: uop = UOp.alu(self._uop, op)
|
||||
else: uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op))
|
||||
elif len(src) == 1:
|
||||
if self._uop._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
|
||||
elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
|
||||
elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
|
||||
else:
|
||||
if isinstance(self, RT) and isinstance(src[0], RV): uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0, (idx[2]%4)//2])))
|
||||
else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
|
||||
else: raise NotImplementedError
|
||||
return type(self)(uop, self.ker)
|
||||
def const_like(self, b): return b
|
||||
|
||||
# override ops that do compute on the src uop
|
||||
def sub(self, x, reverse=False):
|
||||
return self.ufix(x).alu(Ops.ADD, self, inner_op=lambda y: -y) if reverse else self.alu(Ops.ADD, self.ufix(x), inner_op=lambda y: -y)
|
||||
def div(self, x, reverse=False):
|
||||
return self.ufix(x).alu(Ops.MUL, self, inner_op=lambda y: 1/y) if reverse else self.alu(Ops.MUL, self.ufix(x), inner_op=lambda y: 1/y)
|
||||
|
||||
@autowrap(UOp)
|
||||
class GL:
|
||||
def __init__(self, shape, dtype, ker):
|
||||
self.shape, self.dtype = shape, dtype
|
||||
self._uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
|
||||
def __init__(self, uop, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
|
||||
return cls(uop, ker)
|
||||
|
||||
@autowrap(UOp)
|
||||
class ST:
|
||||
def __init__(self, shape, dtype, ker):
|
||||
self.shape, self.dtype = shape, dtype
|
||||
self._uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
|
||||
def __init__(self, uop, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
|
||||
class RT:
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
|
||||
return cls(uop, ker)
|
||||
|
||||
@autowrap(UOp)
|
||||
class RT(TileMathMixin):
|
||||
TILE_ROW_DIM, TILE_COL_DIM = 16, 16
|
||||
BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
|
||||
BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
|
||||
|
||||
def __init__(self, shape, dtype, ker):
|
||||
def __init__(self, uop, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
assert len(shape) == 2
|
||||
assert shape[0] % RT.TILE_ROW_DIM == 0
|
||||
assert shape[1] % RT.TILE_COL_DIM == 0
|
||||
@@ -25,11 +116,16 @@ class RT:
|
||||
height = shape[0] // RT.TILE_ROW_DIM
|
||||
width = shape[1] // RT.TILE_COL_DIM
|
||||
|
||||
self.shape, self.dtype = (height, width, self.BASE_TILE_NEPT), dtype
|
||||
self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG)
|
||||
uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
|
||||
return cls(uop, ker)
|
||||
|
||||
class RV:
|
||||
def __init__(self, length, dtype, layout, ker):
|
||||
@autowrap(UOp)
|
||||
class RV(TileMathMixin):
|
||||
def __init__(self, uop, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
|
||||
@classmethod
|
||||
def create(cls, length, dtype, layout, ker):
|
||||
tiles = length // RT.TILE_ROW_DIM
|
||||
|
||||
match layout:
|
||||
@@ -41,5 +137,7 @@ class RV:
|
||||
outer_dim = tiles
|
||||
case _: raise NotImplementedError(f"rv layout {layout} not implemented")
|
||||
|
||||
self.shape, self.dtype = (outer_dim, inner_dim, 2), dtype
|
||||
self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG)
|
||||
uop = ker.alloc((outer_dim, inner_dim, 2), dtype, AddrSpace.REG)
|
||||
return RV(uop, ker)
|
||||
|
||||
ALL_TILES = UOp | GL | ST | RT | RV
|
||||
|
||||
Reference in New Issue
Block a user