# tinygrad/extra/thunder/tiny/tk/group.py
import math, functools
from typing import cast, Callable
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.helpers import getenv, prod
from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout
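
# A Group is a set of `warps` cooperating warps of WARP_THREADS threads each, loosely
# mirroring ThunderKittens' warp/warpgroup abstraction. Its methods emit UOps into the
# surrounding kernel builder `ker` and return the updated tile. Hypothetical usage
# sketch (the tile names and their allocation are assumed, not shown in this file):
#   warp = Group(1, ker)
#   acc = warp.zero(acc)                  # clear an accumulator register tile
#   acc = warp.mma_AB(acc, a_reg, b_reg)  # acc += a_reg @ b_reg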
class Group:
  def __init__(self, warps:int, ker):
    self.warps = warps
    self.group_threads = warps * WARP_THREADS
    self.ker = ker

  # helpers
  @property
  def laneid(self): return self.ker.threadIdx_x % self.group_threads
  @property
  def warpid(self): return self.laneid // WARP_THREADS
  @property
  def groupid(self): return self.ker.threadIdx_x // self.group_threads
  # ops that only work on a single warp
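  # clear_rid (and copy_rid/map_rid below) are class-level counters that hand out
  # unique ids to UOp.range, so ranges emitted by separate calls never collide.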
  clear_rid = 1000
  def clear(self, reg:ALL_TILES, value:float=0):
    reg = cast(UOp, reg)
    assert self.warps == 1
    rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
    Group.clear_rid += len(reg.shape)
    reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
    self.ker.push_store(reg_store, reg)
    return reg.after(reg_store).reshape(reg.shape)
  def zero(self, reg:ALL_TILES): return self.clear(reg, 0)
  def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
  def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
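  # copy: element-wise copy between same-shaped tiles, casting when base dtypes differ.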
  copy_rid = 300
  def copy(self, dst:ALL_TILES, src:ALL_TILES):
    dst, src = cast(UOp, dst), cast(UOp, src)
    assert self.warps == 1
    assert dst.shape == src.shape
    rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
    Group.copy_rid += len(dst.shape)
    src_load = src[*rngs_for_shape]
    if src.dtype.base != dst.dtype.base:
      src_load = src_load.cast(dst.dtype.base)
    dst_store = dst[*rngs_for_shape].store(src_load).end(*rngs_for_shape)
    self.ker.push_store(dst_store, dst)
    return dst.after(dst_store).reshape(dst.shape)
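  # transpose: swap a tile's height/width subtile axes. This permutes whole base-shape
  # subtiles; the per-thread fragment index `inner` is carried over unchanged.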
  def transpose(self, dst:UOp|RT, src:UOp|RT):
    dst, src = cast(UOp, dst), cast(UOp, src)
    assert self.warps == 1
    for height in self.ker.range(src.shape[-3], track=False):
      for width in self.ker.range(src.shape[-2], track=False):
        for inner in self.ker.range(src.shape[-1], track=False):
          dst_store = dst[width, height, inner].store(src[height, width, inner]).end(height, width, inner)
    self.ker.push_store(dst_store, dst)
    return dst.after(dst_store).reshape(dst.shape)
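  # Matrix-multiply-accumulate variants, named as in ThunderKittens: a trailing `t`
  # marks a transposed operand. All four emit AMD WMMA ops (bf16 inputs, f32
  # accumulator, wave64). The wmma_arg tuple is assumed to follow tinygrad's WMMA
  # spec: (name, (M, N, K), input dtype, accumulator dtype, device, threads, axes, ...).
  # mma_AB: c += a @ b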
  def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
    c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
    assert self.warps == 1
    a_base_shape = cast(RT, a).base_shape
    if a_base_shape.cols == 16:
      wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
    elif a_base_shape.cols == 32:
      wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
    else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
    for height in self.ker.range(c.shape[-3], track=False):
      for width in self.ker.range(c.shape[-2], track=False):
        for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
          if a_base_shape.cols == 16:
            a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
            b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
          elif a_base_shape.cols == 32:
            a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
            b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
          else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
          d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
          out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
          c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
          c_store = UOp.group(*c_i).end(height, width, inner)
    self.ker.push_store(c_store, c)
    return c.after(c_store).reshape(c.shape)
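  # mma_ABt: c += a @ b.T (b is indexed [width, inner], i.e. transposed)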
  def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
    c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
    assert self.warps == 1
    a_base_shape = cast(RT, a).base_shape
    if a_base_shape.cols == 16:
      wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
    elif a_base_shape.cols == 32:
      wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
    else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
    for height in self.ker.range(c.shape[-3], track=False):
      for width in self.ker.range(c.shape[-2], track=False):
        for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
          if a_base_shape.cols == 16:
            a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
            b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
          elif a_base_shape.cols == 32:
            a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
            b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
          else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
          d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
          out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
          c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
          c_store = UOp.group(*c_i).end(height, width, inner)
    self.ker.push_store(c_store, c)
    return c.after(c_store).reshape(c.shape)
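  # mma_AtB: c += a.T @ b (the reduce range walks a.shape[-3], a is indexed [inner, height])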
  def mma_AtB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
    c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
    assert self.warps == 1
    a_base_shape = cast(RT, a).base_shape
    if a_base_shape.cols == 16:
      wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
    elif a_base_shape.cols == 32:
      wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
    else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
    for height in self.ker.range(c.shape[-3], track=False):
      for width in self.ker.range(c.shape[-2], track=False):
        for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
          if a_base_shape.cols == 16:
            a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
            b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
          elif a_base_shape.cols == 32:
            a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
            b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
          else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
          d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
          out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
          c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
          c_store = UOp.group(*c_i).end(height, width, inner)
    self.ker.push_store(c_store, c)
    return c.after(c_store).reshape(c.shape)
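  # mma_AtBt: c += a.T @ b.T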
  def mma_AtBt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
    c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
    assert self.warps == 1
    a_base_shape = cast(RT, a).base_shape
    if a_base_shape.cols == 16:
      wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
    elif a_base_shape.cols == 32:
      wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
    else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
    for height in self.ker.range(c.shape[-3], track=False):
      for width in self.ker.range(c.shape[-2], track=False):
        for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
          if a_base_shape.cols == 16:
            a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
            b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
          elif a_base_shape.cols == 32:
            a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
            b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
          else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
          d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
          out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
          c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
          c_store = UOp.group(*c_i).end(height, width, inner)
    self.ker.push_store(c_store, c)
    return c.after(c_store).reshape(c.shape)
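  # map: apply an element-wise function to a tile in place. `op` takes either the
  # element alone or (element, index_ranges), dispatched on its argcount; e.g. a
  # hypothetical warp.map(x, lambda v: v * 2) doubles every element of x.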
  map_rid = 400
  def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
    a = cast(UOp, a)
    assert self.warps == 1
    rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape))
    Group.map_rid += len(a.shape)
    if op.__code__.co_argcount == 1:
      to_store = op(a[*rngs_for_shape])
    else:
      to_store = op(a[*rngs_for_shape], rngs_for_shape)
    a_store = a[*rngs_for_shape].store(to_store).end(*rngs_for_shape)
    self.ker.push_store(a_store, a)
    return a.after(a_store).reshape(a.shape)
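  # row_reduce: fold each tile row into `vec` with `op`, in three stages: per-thread
  # accumulation over the row's register fragments, a cross-lane pass through shared
  # memory (3 steps at lane offsets of 16, which assumes 16 lanes per tile row in a
  # 64-thread group), then a final combine with the existing `vec` entry.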
  def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp], init_value:float=0.0):
    vec, src = cast(UOp, vec), cast(UOp, src)
    assert self.warps == 1
    red_local = self.ker.alloc((self.group_threads,), src.dtype.base, AddrSpace.LOCAL)
    red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
    for height in self.ker.range(src.shape[-3], track=False):
      i = UOp.range(red_reg.size, Group.clear_rid)
      Group.clear_rid += 1
      red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
      reg_store = red_reg.flatten()[i].store(init_value).end(i)
      red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
      for width in self.ker.range(src.shape[-2], axis_type=AxisType.REDUCE, track=False):
        for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
          reg_store = red_reg[0].store(op(red_reg[0], src[height, width, inner])).end(width, inner)
      red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
      # store to shared memory
      red_local_store = red_local[self.laneid].store(red_reg[0])
      red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
      # reduce from shared memory
      for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
        offset = (self.laneid + (1 + inner) * 16) % self.group_threads
        reg_store = red_reg[0].store(op(red_reg[0], red_local[offset])).end(inner)
      red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
      # reduce with vec
      vec_store = vec[height, 0].store(op(vec[height, 0], red_reg[0])).end(height)
    self.ker.push_store(vec_store, vec)
    return vec.after(vec_store).reshape(vec.shape)
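  # col_reduce: same three-stage scheme as row_reduce, but reduces over the height
  # axis to fold each tile column into `vec`.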
  def col_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp], init_value:float=0.0):
    vec, src = cast(UOp, vec), cast(UOp, src)
    assert self.warps == 1
    red_local = self.ker.alloc((self.group_threads,), src.dtype.base, AddrSpace.LOCAL)
    red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
    for width in self.ker.range(src.shape[-2], track=False):
      i = UOp.range(red_reg.size, Group.clear_rid)
      Group.clear_rid += 1
      red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
      reg_store = red_reg.flatten()[i].store(init_value).end(i)
      red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
      for height in self.ker.range(src.shape[-3], axis_type=AxisType.REDUCE, track=False):
        for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
          reg_store = red_reg[0].store(op(red_reg[0], src[height, width, inner])).end(height, inner)
      red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
      # store to shared memory
      red_local_store = red_local[self.laneid].store(red_reg[0])
      red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
      # reduce from shared memory
      for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
        offset = (self.laneid + (1 + inner) * 16) % self.group_threads
        reg_store = red_reg[0].store(op(red_reg[0], red_local[offset])).end(inner)
      red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
      # reduce with vec
      vec_store = vec[width, 0].store(op(vec[width, 0], red_reg[0])).end(width)
    self.ker.push_store(vec_store, vec)
    return vec.after(vec_store).reshape(vec.shape)
  # ops that can work across multiple warps
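  # load: copy a tile into `dst` from `src` in another address space. Each branch
  # handles one (dst, src) addrspace pair: shared->register applies the layout-aware
  # swizzle, global->shared is a group-wide coalesced copy ending in a barrier, and
  # global->register gathers each lane's fragments straight from global memory.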
  def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0):
    dst, src = cast(UOp, dst), cast(UOp, src)
    assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
    dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
    if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
      # shared -> register
      laneid = self.ker.laneid
      rt, st = cast(RT, dst), cast(ST, src)
      elements_per_thread = rt.base_shape.elements_per_thread
      for height in self.ker.range(dst.shape[-3], track=False):
        for width in self.ker.range(dst.shape[-2], track=False):
          for inner in self.ker.range(elements_per_thread, track=False):
            if rt.layout != st.layout:
              row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
              col = laneid % rt.base_shape.cols
            else:
              row = laneid % rt.base_shape.rows
              col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
            srow, scol = cast(ST, src).swizzle(row, col)
            src_load = src[*idxs[:-2], height, width, srow, scol]
            if src.dtype.base != dst.dtype.base:
              src_load = src_load.cast(dst.dtype.base)
            dst_store = dst[*dst_idxs, height, width, inner].store(src_load)
            dst_store = dst_store.end(height, width, inner)
    elif dst_dtype.addrspace == AddrSpace.LOCAL and src_dtype.addrspace == AddrSpace.GLOBAL:
      # global -> shared
      srcf = src.flatten()
      row_stride = prod(src.shape[axis+1:])
      st = cast(ST, dst)
      idxs = tuple(idx * st.rows if i == axis else idx for i, idx in enumerate(idxs))
      idxs = tuple(idx * st.cols if i == 3 else idx for i, idx in enumerate(idxs))
      src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
      for height in self.ker.range(dst.shape[-4], track=False):
        for width in self.ker.range(dst.shape[-3], track=False):
          elements_per_thread = st.base_shape.elements_per_thread
          memcpy_per_row = st.base_shape.cols // elements_per_thread
          total_calls = st.base_shape.num_elements // (self.group_threads * elements_per_thread)
          for outer in self.ker.range(total_calls, track=False):
            for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
              load_idx = outer * self.group_threads + self.laneid
              row = load_idx // memcpy_per_row
              col = (load_idx * elements_per_thread) % st.base_shape.cols + inner
              srow, scol = cast(ST, dst).swizzle(row, col)
              src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
              src_i += row * row_stride + col
              src_load = srcf[src_i]
              if src.dtype.base != dst.dtype.base:
                src_load = src_load.cast(dst.dtype.base)
              dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
              dst_store = dst_store.end(height, width, outer, inner).barrier()
    elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL:
      # global -> register
      srcf = src.flatten()
      row_stride = prod(src.shape[axis+1:])
      laneid = self.ker.laneid
      rt = cast(RT, dst)
      elements_per_thread = rt.base_shape.elements_per_thread
      idxs = tuple(idx * dst.shape[-3] * rt.base_shape.rows if i == axis else idx for i, idx in enumerate(idxs))
      idxs = tuple(idx * dst.shape[-2] * rt.base_shape.cols if i == 3 else idx for i, idx in enumerate(idxs))
      src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
      for height in self.ker.range(dst.shape[-3], track=False):
        for width in self.ker.range(dst.shape[-2], track=False):
          for inner in self.ker.range(elements_per_thread, track=False):
            base_row = height * rt.base_shape.rows
            base_col = width * rt.base_shape.cols
            if rt.layout == TileLayout.COL:
              row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
              col = laneid % rt.base_shape.cols
            else:
              row = laneid % rt.base_shape.rows
              col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
            srow, scol = base_row + row, base_col + col
            src_i += srow * row_stride + scol
            src_load = srcf[src_i]
            if src.dtype.base != dst.dtype.base:
              src_load = src_load.cast(dst.dtype.base)
            dst_store = dst[*dst_idxs, height, width, inner].store(src_load).end(height, width, inner)
    else:
      raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
    self.ker.push_store(dst_store, dst)
    return dst.after(dst_store).reshape(dst.shape)
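  # store: the inverse of load; only register->global is implemented, scattering each
  # lane's fragments with the same row/col mapping the register load path uses.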
  def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis:int=0):
    dst, src = cast(UOp, dst), cast(UOp, src)
    assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
    dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
    if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL:
      # register -> global
      dstf = dst.flatten()
      row_stride = prod(dst.shape[axis+1:])
      laneid = self.ker.laneid
      rt = cast(RT, src)
      elements_per_thread = rt.base_shape.elements_per_thread
      idxs = tuple(idx * src.shape[-3] * rt.base_shape.rows if i == axis else idx for i, idx in enumerate(idxs))
      idxs = tuple(idx * src.shape[-2] * rt.base_shape.cols if i == 3 else idx for i, idx in enumerate(idxs))
      dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3]
      for height in self.ker.range(src.shape[-3], track=False):
        for width in self.ker.range(src.shape[-2], track=False):
          for inner in self.ker.range(elements_per_thread, track=False):
            base_row = height * rt.base_shape.rows
            base_col = width * rt.base_shape.cols
            if rt.layout == TileLayout.COL:
              row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
              col = laneid % rt.base_shape.cols
            else:
              row = laneid % rt.base_shape.rows
              col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
            srow, scol = base_row + row, base_col + col
            dst_i += srow * row_stride + scol
            src_load = src[*src_idxs, height, width, inner]
            if src.dtype.base != dst.dtype.base:
              src_load = src_load.cast(dst.dtype.base)
            dst_store = dstf[dst_i].store(src_load).end(height, width, inner)
    else:
      raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
    self.ker.push_store(dst_store, dst)
    return dst.after(dst_store).reshape(dst.shape)