wozeparrot
2025-11-25 15:49:44 -08:00
committed by GitHub
parent 436ab6bfc7
commit ffc31a23f4
5 changed files with 873 additions and 275 deletions

View File

@@ -7,22 +7,21 @@ from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.helpers import getenv, prod
from extra.thunder.tiny.tk import WARP_THREADS
-from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, ST, RT, RV
+from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout
class Group:
def __init__(self, warps:int, ker):
self.warps = warps
self.group_threads = warps * WARP_THREADS
-self.threadIdx_x = ker.threadIdx_x
self.ker = ker
# helpers
@property
-def laneid(self): return self.threadIdx_x % self.group_threads
+def laneid(self): return self.ker.threadIdx_x % self.group_threads
@property
def warpid(self): return self.laneid // WARP_THREADS
@property
-def groupid(self): return self.threadIdx_x // self.group_threads
+def groupid(self): return self.ker.threadIdx_x // self.group_threads
# ops that only work on a single warp
@@ -40,6 +39,7 @@ class Group:
return reg.after(reg_store).reshape(reg.shape)
def zero(self, reg:ALL_TILES): return self.clear(reg, 0)
def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
+def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
copy_rid = 300
@@ -51,7 +51,22 @@ class Group:
rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
Group.copy_rid += len(dst.shape)
-dst_store = dst[*rngs_for_shape].store(src[*rngs_for_shape].cast(dst.dtype.base)).end(*rngs_for_shape)
+src_load = src[*rngs_for_shape]
+if src.dtype.base != dst.dtype.base:
+src_load = src_load.cast(dst.dtype.base)
+dst_store = dst[*rngs_for_shape].store(src_load).end(*rngs_for_shape)
self.ker.push_store(dst_store, dst)
return dst.after(dst_store).reshape(dst.shape)
+def transpose(self, dst:UOp|RT, src:UOp|RT):
+dst, src = cast(UOp, dst), cast(UOp, src)
+assert self.warps == 1
+for height in self.ker.range(src.shape[-3], track=False):
+for width in self.ker.range(src.shape[-2], track=False):
+for inner in self.ker.range(src.shape[-1], track=False):
+dst_store = dst[width, height, inner].store(src[height, width, inner]).end(height, width, inner)
+self.ker.push_store(dst_store, dst)
+return dst.after(dst_store).reshape(dst.shape)
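The new transpose only swaps the two tile coordinates and keeps the per-lane inner index; how the inner elements map to rows and columns is left to the tile layout. A minimal NumPy sketch of that index swap (the array names are illustrative, not from the repo):

```python
import numpy as np

# src: a (height, width, inner) grid of register tiles; tile (h, w) moves to (w, h)
src = np.arange(2 * 3 * 4).reshape(2, 3, 4)
dst = np.empty((3, 2, 4), dtype=src.dtype)
for h in range(src.shape[0]):
    for w in range(src.shape[1]):
        for i in range(src.shape[2]):
            # mirrors dst[width, height, inner].store(src[height, width, inner])
            dst[w, h, i] = src[h, w, i]
assert (dst == src.transpose(1, 0, 2)).all()  # a transpose at tile granularity
```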
@@ -60,20 +75,27 @@ class Group:
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
assert self.warps == 1
+a_base_shape = cast(RT, a).base_shape
+if a_base_shape.cols == 16:
+wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
+elif a_base_shape.cols == 32:
+wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
+else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
for height in self.ker.range(c.shape[-3], track=False):
for width in self.ker.range(c.shape[-2], track=False):
for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
if a_base_shape.cols == 16:
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
elif a_base_shape.cols == 32:
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
b_in1 = UOp.vectorize(*([b[inner, width, i] for i in range(2)] + [b[inner, width, 4+i] for i in range(2)]))
c_out1 = UOp.vectorize(*[c[height, width, i] for i in range(4)])
b_in2 = UOp.vectorize(*([b[inner, width, 2+i] for i in range(2)] + [b[inner, width, 6+i] for i in range(2)]))
c_out2 = UOp.vectorize(*[c[height, width, 4+i] for i in range(4)])
out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg)
out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg)
c_i = [c[height, width, i].store(out1.gep(i)) for i in range(4)] + [c[height, width, 4+i].store(out2.gep(i)) for i in range(4)]
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
c_store = UOp.group(*c_i).end(height, width, inner)
self.ker.push_store(c_store, c)
@@ -83,20 +105,87 @@ class Group:
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
assert self.warps == 1
+a_base_shape = cast(RT, a).base_shape
+if a_base_shape.cols == 16:
+wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
+elif a_base_shape.cols == 32:
+wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
+else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
for height in self.ker.range(c.shape[-3], track=False):
for width in self.ker.range(c.shape[-2], track=False):
for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
if a_base_shape.cols == 16:
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
elif a_base_shape.cols == 32:
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
b_in1 = UOp.vectorize(*([b[width, inner, i] for i in range(2)] + [b[width, inner, 4+i] for i in range(2)]))
c_out1 = UOp.vectorize(*[c[height, width, i] for i in range(4)])
b_in2 = UOp.vectorize(*([b[width, inner, 2+i] for i in range(2)] + [b[width, inner, 6+i] for i in range(2)]))
c_out2 = UOp.vectorize(*[c[height, width, 4+i] for i in range(4)])
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
c_store = UOp.group(*c_i).end(height, width, inner)
out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg)
out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg)
c_i = [c[height, width, i].store(out1.gep(i)) for i in range(4)] + [c[height, width, 4+i].store(out2.gep(i)) for i in range(4)]
self.ker.push_store(c_store, c)
return c.after(c_store).reshape(c.shape)
+def mma_AtB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
+c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
+assert self.warps == 1
+a_base_shape = cast(RT, a).base_shape
+if a_base_shape.cols == 16:
+wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
+elif a_base_shape.cols == 32:
+wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
+else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
+for height in self.ker.range(c.shape[-3], track=False):
+for width in self.ker.range(c.shape[-2], track=False):
+for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
+if a_base_shape.cols == 16:
+a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
+b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
+elif a_base_shape.cols == 32:
+a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
+b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
+else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
+d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
+out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
+c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
+c_store = UOp.group(*c_i).end(height, width, inner)
+self.ker.push_store(c_store, c)
+return c.after(c_store).reshape(c.shape)
+def mma_AtBt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
+c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
+assert self.warps == 1
+a_base_shape = cast(RT, a).base_shape
+if a_base_shape.cols == 16:
+wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
+elif a_base_shape.cols == 32:
+wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
+else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
+for height in self.ker.range(c.shape[-3], track=False):
+for width in self.ker.range(c.shape[-2], track=False):
+for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
+if a_base_shape.cols == 16:
+a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
+b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
+elif a_base_shape.cols == 32:
+a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
+b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
+else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
+d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
+out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
+c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
+c_store = UOp.group(*c_i).end(height, width, inner)
+self.ker.push_store(c_store, c)
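All four mma_* variants pick between the same two AMD WMMA shapes, and the vectorize widths of 4 and 8 fall directly out of fragment size divided by wavefront size. A quick standalone check of that arithmetic, with the constants taken from the wmma_arg tuples above:

```python
WARP_THREADS = 64  # AMD wavefront size, the 64 in the wmma_arg tuples

def elements_per_thread(rows: int, cols: int, threads: int = WARP_THREADS) -> int:
    return rows * cols // threads

# 16x16x16: A, B and the accumulator all hold 4 elements per lane -> vec(4)
assert elements_per_thread(16, 16) == 4
# 16x16x32: the 16x32 A and 32x16 B fragments hold 8 per lane -> vec(8),
# while the 16x16 f32 accumulator stays at 4 -> dtypes.float32.vec(4)
assert elements_per_thread(16, 32) == 8
```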
@@ -120,171 +209,213 @@ class Group:
self.ker.push_store(a_store, a)
return a.after(a_store).reshape(a.shape)
-def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp]):
+def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp], init_value:float=0.0):
vec, src = cast(UOp, vec), cast(UOp, src)
assert self.warps == 1
-red_local = self.ker.alloc((self.group_threads, 2), src.dtype.base, AddrSpace.LOCAL)
-red_reg = self.ker.alloc((2,), src.dtype.base, AddrSpace.REG)
+red_local = self.ker.alloc((self.group_threads,), src.dtype.base, AddrSpace.LOCAL)
+red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
for height in self.ker.range(src.shape[-3], track=False):
i = UOp.range(red_reg.size, Group.clear_rid)
Group.clear_rid += 1
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
-reg_store = red_reg.flatten()[i].store(0.).end(i)
+reg_store = red_reg.flatten()[i].store(init_value).end(i)
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
-for outer in self.ker.range(2, track=False):
-for width in self.ker.range(src.shape[-2], axis_type=AxisType.REDUCE, track=False):
-for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
-elem_index = inner + 2 * (inner // 2) + outer * 2
-reg_store = red_reg[outer].store(op(red_reg[outer], src[height, width, elem_index])).end(inner, width, outer)
-red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
-# store to shared memory
-for outer in self.ker.range(2, track=False):
-red_local_store = red_local[self.laneid, outer].store(red_reg[outer]).end(outer)
-red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
-# reduce from shared memory
-for outer in self.ker.range(2, track=False):
-for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
-offset = (self.laneid // 4) * 4 + ((self.laneid + inner + 1) % 4)
-reg_store = red_reg[outer].store(op(red_reg[outer], red_local[offset, outer])).end(inner, outer)
+for width in self.ker.range(src.shape[-2], axis_type=AxisType.REDUCE, track=False):
+for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
+reg_store = red_reg[0].store(op(red_reg[0], src[height, width, inner])).end(width, inner)
+red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
+# store to shared memory
+red_local_store = red_local[self.laneid].store(red_reg[0])
+red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
+# reduce from shared memory
+for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
+offset = (self.laneid + (1 + inner) * 16) % self.group_threads
+reg_store = red_reg[0].store(op(red_reg[0], red_local[offset])).end(inner)
+red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
# reduce with vec
-for outer in self.ker.range(2, track=False):
-vec_store = vec[height, 0, outer].store(op(vec[height, 0, outer], red_reg[outer])).end(outer, height)
+vec_store = vec[height, 0].store(op(vec[height, 0], red_reg[0])).end(height)
self.ker.push_store(vec_store, vec)
return vec.after(vec_store).reshape(vec.shape)
+def col_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp], init_value:float=0.0):
+vec, src = cast(UOp, vec), cast(UOp, src)
+assert self.warps == 1
+red_local = self.ker.alloc((self.group_threads,), src.dtype.base, AddrSpace.LOCAL)
+red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
+for width in self.ker.range(src.shape[-2], track=False):
+i = UOp.range(red_reg.size, Group.clear_rid)
+Group.clear_rid += 1
+red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
+reg_store = red_reg.flatten()[i].store(init_value).end(i)
+red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
+for height in self.ker.range(src.shape[-3], axis_type=AxisType.REDUCE, track=False):
+for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
+reg_store = red_reg[0].store(op(red_reg[0], src[height, width, inner])).end(height, inner)
+red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
+# store to shared memory
+red_local_store = red_local[self.laneid].store(red_reg[0])
+red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
+# reduce from shared memory
+for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
+offset = (self.laneid + (1 + inner) * 16) % self.group_threads
+reg_store = red_reg[0].store(op(red_reg[0], red_local[offset])).end(inner)
+red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
+# reduce with vec
+vec_store = vec[width, 0].store(op(vec[width, 0], red_reg[0])).end(width)
+self.ker.push_store(vec_store, vec)
+return vec.after(vec_store).reshape(vec.shape)
# ops that can work across multiple warps
-LOAD_INNER = 4
-def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
+def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0):
dst, src = cast(UOp, dst), cast(UOp, src)
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
-srcf = src.flatten(-2)
-if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4)
-else: local_warpid = self.warpid
-warp_laneid = self.threadIdx_x % WARP_THREADS
+laneid = self.ker.laneid
+rt, st = cast(RT, dst), cast(ST, src)
+elements_per_thread = rt.base_shape.elements_per_thread
for height in self.ker.range(dst.shape[-3], track=False):
for width in self.ker.range(dst.shape[-2], track=False):
-for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
-base_row = (local_warpid * dst.shape[-3] + height) * RT.BASE_TILE_ROWS
-base_col = width * RT.BASE_TILE_COLS
-if not transpose:
-row = base_row + (warp_laneid // 4)
-col = base_col + 2 * (warp_laneid % 4)
-row_offset = ((inner % 4) // 2) * 8
-col_offset = (inner % 2) + (inner // 4) * 8
+for inner in self.ker.range(elements_per_thread, track=False):
+if rt.layout != st.layout:
+row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
+col = laneid % rt.base_shape.cols
else:
-row = base_row + 2 * (warp_laneid % 4)
-col = base_col + (warp_laneid // 4)
-row_offset = (inner % 2) + (inner // 4) * 8
-col_offset = ((inner % 4) // 2) * 8
-src_i_last = (row + row_offset) * src.shape[-1] + col + col_offset
-dst_store = dst[*dst_idxs, height, width, inner].store(srcf[*idxs[:-2], src_i_last])
+row = laneid % rt.base_shape.rows
+col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
+srow, scol = cast(ST, src).swizzle(row, col)
+src_load = src[*idxs[:-2], height, width, srow, scol]
+if src.dtype.base != dst.dtype.base:
+src_load = src_load.cast(dst.dtype.base)
+dst_store = dst[*dst_idxs, height, width, inner].store(src_load)
dst_store = dst_store.end(height, width, inner)
elif dst_dtype.addrspace == AddrSpace.LOCAL and src_dtype.addrspace == AddrSpace.GLOBAL:
-dstf = dst.flatten(-2)
srcf = src.flatten()
row_stride = prod(src.shape[axis+1:])
-idxs = tuple(idx * dst.shape[-2] if i == axis else idx for i, idx in enumerate(idxs))
-idxs = tuple(idx * dst.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs))
+st = cast(ST, dst)
+idxs = tuple(idx * st.rows if i == axis else idx for i, idx in enumerate(idxs))
+idxs = tuple(idx * st.cols if i == 3 else idx for i, idx in enumerate(idxs))
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
-memcpy_per_row = dst.shape[-1] // Group.LOAD_INNER
-total_calls = prod(dst.shape[-2:]) // (self.group_threads * Group.LOAD_INNER)
-for outer in self.ker.range(total_calls, track=False):
-for inner in self.ker.range(Group.LOAD_INNER, track=False):
-load_idx = outer * self.group_threads + self.laneid
-row = load_idx // memcpy_per_row
-col = (load_idx * Group.LOAD_INNER) % dst.shape[-1]
-dst_i = row * dst.shape[-1] + col + inner
-src_i += row * row_stride + col + inner
-dst_store = dstf[*dst_idxs, dst_i].store(srcf[src_i]).end(outer, inner)
+for height in self.ker.range(dst.shape[-4], track=False):
+for width in self.ker.range(dst.shape[-3], track=False):
+elements_per_thread = st.base_shape.elements_per_thread
+memcpy_per_row = st.base_shape.cols // elements_per_thread
+total_calls = st.base_shape.num_elements // (self.group_threads * elements_per_thread)
+for outer in self.ker.range(total_calls, track=False):
+for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
+load_idx = outer * self.group_threads + self.laneid
+row = load_idx // memcpy_per_row
+col = (load_idx * elements_per_thread) % st.base_shape.cols + inner
+srow, scol = cast(ST, dst).swizzle(row, col)
+src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
+src_i += row * row_stride + col
+src_load = srcf[src_i]
+if src.dtype.base != dst.dtype.base:
+src_load = src_load.cast(dst.dtype.base)
+dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
+dst_store = dst_store.end(height, width, outer, inner).barrier()
+elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL:
+srcf = src.flatten()
+row_stride = prod(src.shape[axis+1:])
+laneid = self.ker.laneid
+rt = cast(RT, dst)
+elements_per_thread = rt.base_shape.elements_per_thread
+idxs = tuple(idx * dst.shape[-3] * rt.base_shape.rows if i == axis else idx for i, idx in enumerate(idxs))
+idxs = tuple(idx * dst.shape[-2] * rt.base_shape.cols if i == 3 else idx for i, idx in enumerate(idxs))
+src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
+for height in self.ker.range(dst.shape[-3], track=False):
+for width in self.ker.range(dst.shape[-2], track=False):
+for inner in self.ker.range(elements_per_thread, track=False):
+base_row = height * rt.base_shape.rows
+base_col = width * rt.base_shape.cols
+if rt.layout == TileLayout.COL:
+row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
+col = laneid % rt.base_shape.cols
+else:
+row = laneid % rt.base_shape.rows
+col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
+srow, scol = base_row + row, base_col + col
+src_i += srow * row_stride + scol
+src_load = srcf[src_i]
+if src.dtype.base != dst.dtype.base:
+src_load = src_load.cast(dst.dtype.base)
+dst_store = dst[*dst_idxs, height, width, inner].store(src_load).end(height, width, inner)
else:
raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
-return dst.after(dst_store.barrier()).reshape(dst.shape)
+self.ker.push_store(dst_store, dst)
+return dst.after(dst_store).reshape(dst.shape)
-STORE_INNER = 4
-def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
+def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis:int=0):
dst, src = cast(UOp, dst), cast(UOp, src)
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
-if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
-dstf = dst.flatten(-2)
-if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4)
-else: local_warpid = self.warpid
-warp_laneid = self.threadIdx_x % WARP_THREADS
-for height in self.ker.range(src.shape[-3], track=False):
-for width in self.ker.range(src.shape[-2], track=False):
-for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
-base_row = (local_warpid * src.shape[-3] + height) * RT.BASE_TILE_ROWS
-base_col = width * RT.BASE_TILE_COLS
-if not transpose:
-row = base_row + (warp_laneid // 4)
-col = base_col + 2 * (warp_laneid % 4)
-row_offset = ((inner % 4) // 2) * 8
-col_offset = (inner % 2) + (inner // 4) * 8
-else:
-row = base_row + 2 * (warp_laneid % 4)
-col = base_col + (warp_laneid // 4)
-row_offset = (inner % 2) + (inner // 4) * 8
-col_offset = ((inner % 4) // 2) * 8
-dst_i_last = (row + row_offset) * dst.shape[-1] + col + col_offset
-dst_store = dstf[*idxs[:-2], dst_i_last].store(src[*src_idxs, height, width, inner])
-dst_store = dst_store.end(height, width, inner)
-elif src_dtype.addrspace == AddrSpace.LOCAL and dst_dtype.addrspace == AddrSpace.GLOBAL:
+if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL:
dstf = dst.flatten()
row_stride = prod(dst.shape[axis+1:])
-idxs = tuple(idx * src.shape[-2] if i == axis else idx for i, idx in enumerate(idxs))
-idxs = tuple(idx * src.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs))
+laneid = self.ker.laneid
+rt = cast(RT, src)
+elements_per_thread = rt.base_shape.elements_per_thread
+idxs = tuple(idx * src.shape[-3] * rt.base_shape.rows if i == axis else idx for i, idx in enumerate(idxs))
+idxs = tuple(idx * src.shape[-2] * rt.base_shape.cols if i == 3 else idx for i, idx in enumerate(idxs))
dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3]
-srcf = src.flatten(-2)
-memcpy_per_row = src.shape[-1] // Group.STORE_INNER
-total_calls = prod(src.shape[-2:]) // (self.group_threads * Group.STORE_INNER)
-for outer in self.ker.range(total_calls, track=False):
-for inner in self.ker.range(Group.STORE_INNER, track=False):
-load_idx = outer * self.group_threads + self.laneid
-row = load_idx // memcpy_per_row
-col = (load_idx * Group.STORE_INNER) % src.shape[-1]
-src_i = row * src.shape[-1] + col + inner
-dst_i += row * row_stride + col + inner
-dst_store = dstf[dst_i].store(srcf[*src_idxs, src_i]).end(outer, inner)
+for height in self.ker.range(src.shape[-3], track=False):
+for width in self.ker.range(src.shape[-2], track=False):
+for inner in self.ker.range(elements_per_thread, track=False):
+base_row = height * rt.base_shape.rows
+base_col = width * rt.base_shape.cols
+if rt.layout == TileLayout.COL:
+row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
+col = laneid % rt.base_shape.cols
+else:
+row = laneid % rt.base_shape.rows
+col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
+srow, scol = base_row + row, base_col + col
+dst_i += srow * row_stride + scol
+src_load = src[*src_idxs, height, width, inner]
+if src.dtype.base != dst.dtype.base:
+src_load = src_load.cast(dst.dtype.base)
+dst_store = dstf[dst_i].store(src_load).end(height, width, inner)
else:
raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
-return dst.after(dst_store.barrier()).reshape(dst.shape)
+self.ker.push_store(dst_store, dst)
+return dst.after(dst_store).reshape(dst.shape)
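row_reduce and col_reduce now keep a single partial per lane, spill it to a group_threads-wide shared buffer, and have each lane fold in the slots 16, 32 and 48 lanes away. A standalone check (plain Python, no tinygrad) that this rotation makes every lane combine exactly the four partials sharing laneid % 16:

```python
GROUP_THREADS = 64  # a single warp, as the assert self.warps == 1 requires

for laneid in range(GROUP_THREADS):
    # the three shared-memory slots each lane reads, per the new offset formula
    partners = [(laneid + (1 + inner) * 16) % GROUP_THREADS for inner in range(3)]
    # together with its own slot, that is one full column of the 4x16 lane grid
    assert {laneid, *partners} == {laneid % 16 + k * 16 for k in range(4)}
```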

View File

@@ -2,7 +2,7 @@ from contextlib import AbstractContextManager
from tinygrad.uop.ops import UOp, KernelInfo, AxisType, AddrSpace
from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.group import Group
-from extra.thunder.tiny.tk.tiles import GL, ST, RT, RV
+from extra.thunder.tiny.tk.tiles import GL, ST_16X16, ST_16X16_SWIZZLED, ST, RT_16X16, RT, RV, TileLayout, VecLayout
class _tk_range:
user_rid = 0
@@ -35,6 +35,8 @@ class Kernel(AbstractContextManager):
@property
def warpid(self): return self.threadIdx_x // WARP_THREADS
+@property
+def laneid(self): return self.threadIdx_x % WARP_THREADS
def __enter__(self): return self
def __exit__(self, exc_type, exc_value, traceback): pass
@@ -72,9 +74,9 @@ class Kernel(AbstractContextManager):
return uop
def gl(self, shape, dtype): return GL.create(shape, dtype, self)
-def st(self, shape, dtype): return ST.create(shape, dtype, self)
-def rt(self, shape, dtype): return RT.create(shape, dtype, self)
-def rv(self, length, dtype, layout="naive"): return RV.create(length, dtype, layout, self)
+def st(self, shape, dtype, layout=TileLayout.ROW, base_shape=ST_16X16): return ST.create(shape, dtype, layout, base_shape, self)
+def rt(self, shape, dtype, layout=TileLayout.ROW, base_shape=RT_16X16): return RT.create(shape, dtype, layout, base_shape, self)
+def rv(self, length, dtype, layout=VecLayout.ORTHO, rt_base_shape=RT_16X16): return RV.create(length, dtype, layout, rt_base_shape, self)
def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop))
@@ -92,4 +94,4 @@ class Kernel(AbstractContextManager):
def endrange(self):
last_store = self.store_stack.pop()
last_range = self.range_stack.pop()
-return last_store[1].after(last_store[0].barrier().end(last_range._rng)).reshape(last_store[1].shape)
+return last_store[1].after(last_store[0].end(last_range._rng)).reshape(last_store[1].shape)
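All of the id helpers in Kernel and Group are slices of the same flat threadIdx_x. A small sketch of the decomposition, assuming WARP_THREADS = 64 as elsewhere in this PR (the example numbers are made up):

```python
WARP_THREADS = 64
warps = 4                              # a hypothetical 4-warp group
group_threads = warps * WARP_THREADS   # 256

t = 193                                # some flat thread index
assert t // WARP_THREADS == 3          # Kernel.warpid
assert t % WARP_THREADS == 1           # Kernel.laneid (the new property)
assert t % group_threads == 193        # Group.laneid
assert (t % group_threads) // WARP_THREADS == 3  # Group.warpid
assert t // group_threads == 0         # Group.groupid
```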

View File

@@ -1,5 +1,8 @@
+from enum import Enum, auto
import functools
-from tinygrad.dtype import AddrSpace
+from typing import Callable
+from dataclasses import dataclass
+from tinygrad.dtype import AddrSpace, DType
from tinygrad.mixin import MathMixin
from tinygrad.uop.ops import UOp, Ops
@@ -66,7 +69,10 @@ class TileMathMixin(MathMixin):
elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
else:
-if isinstance(self, RT) and isinstance(src[0], RV): uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0, (idx[2]%4)//2])))
+if isinstance(self, RT) and isinstance(src[0], RV):
+match self.layout:
+case TileLayout.ROW: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0])))
+case TileLayout.COL: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[1], 0])))
else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
else: raise NotImplementedError
return self.ruop(uop)
@@ -80,76 +86,186 @@ class TileMathMixin(MathMixin):
@autowrap(UOp)
class GL:
-def __init__(self, uop, ker):
+def __init__(self, uop:UOp, ker):
self._uop, self.ker = uop, ker
-def ruop(self, uop):
+def ruop(self, uop:UOp):
return GL(uop, self.ker)
@classmethod
-def create(cls, shape, dtype, ker):
+def create(cls, shape, dtype:DType, ker):
uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
return cls(uop, ker)
+class TileLayout(Enum):
+ROW = auto()
+COL = auto()
+class VecLayout(Enum):
+ORTHO = auto()
+@dataclass(frozen=True)
+class BaseShape:
+rows: int
+cols: int
+@property
+def num_elements(self): return self.rows * self.cols
+@property
+def elements_per_thread(self): return self.num_elements // WARP_THREADS
+@dataclass(frozen=True)
+class STBaseShape(BaseShape):
+_swizzle: Callable[[UOp, DType], UOp]
+bytes_per_thread: Callable[[DType], int]
+def swizzle(self, row, col, dtype:DType):
+offset = row * self.cols + col
+offset *= dtype.itemsize
+offset = self._swizzle(offset, dtype)
+offset //= dtype.itemsize
+return offset
+def st_16x16_swizzle(offset:UOp, _): return offset
+def st_16x16_bpt(dtype:DType):
+if dtype.itemsize == 2 or dtype.itemsize == 4: return 16
+else: raise NotImplementedError
+ST_16X16 = STBaseShape(16, 16, st_16x16_swizzle, st_16x16_bpt)
+def st_16x16_swizzled_swizzle(offset:UOp, dtype:DType):
+if dtype.itemsize == 2:
+swizzle = ((offset % 512) >> 7) << 3
+return offset ^ swizzle
+elif dtype.itemsize == 4:
+return offset
+else: raise NotImplementedError
+def st_16x16_swizzled_bpt(dtype:DType):
+if dtype.itemsize == 2: return 4
+elif dtype.itemsize == 4: return 16
+else: raise NotImplementedError
+ST_16X16_SWIZZLED = STBaseShape(16, 16, st_16x16_swizzled_swizzle, st_16x16_swizzled_bpt)
+def st_32x32_swizzle(offset:UOp, dtype:DType):
+if dtype.itemsize == 2:
+first_swizzle = ((offset % 1024) >> 9) << 5
+second_swizzle = ((offset % 2048) >> 10) << 4
+return offset ^ first_swizzle ^ second_swizzle
+elif dtype.itemsize == 4:
+return offset
+else: raise NotImplementedError
+def st_32x32_bpt(dtype:DType):
+if dtype.itemsize == 2 or dtype.itemsize == 4: return 16
+else: raise NotImplementedError
+ST_32X32 = STBaseShape(32, 32, st_32x32_swizzle, st_32x32_bpt)
+def st_16x32_swizzle(offset:UOp, dtype:DType):
+if dtype.itemsize == 2:
+swizzle = ((offset % 1024) >> 9) << 5
+return offset ^ swizzle
+elif dtype.itemsize == 4:
+return offset
+else: raise NotImplementedError
+def st_16x32_bpt(dtype:DType):
+if dtype.itemsize == 2 or dtype.itemsize == 4: return 16
+else: raise NotImplementedError
+ST_16X32 = STBaseShape(16, 32, st_16x32_swizzle, st_16x32_bpt)
+def st_32x16_swizzle(offset:UOp, dtype:DType):
+if dtype.itemsize == 2:
+swizzle = ((offset % 1024) >> 9) << 4
+return offset ^ swizzle
+elif dtype.itemsize == 4:
+return offset
+else: raise NotImplementedError
+def st_32x16_bpt(dtype:DType):
+if dtype.itemsize == 2 or dtype.itemsize == 4: return 16
+else: raise NotImplementedError
+ST_32X16 = STBaseShape(32, 16, st_32x16_swizzle, st_32x16_bpt)
@autowrap(UOp)
class ST:
-def __init__(self, uop, ker):
-self._uop, self.ker = uop, ker
+def __init__(self, uop:UOp, rows:int, cols:int, layout:TileLayout, base_shape:STBaseShape, ker):
+self._uop, self.rows, self.cols, self.layout, self.base_shape, self.ker = uop, rows, cols, layout, base_shape, ker
-def ruop(self, uop):
-return ST(uop, self.ker)
+def ruop(self, uop:UOp):
+return ST(uop, self.rows, self.cols, self.layout, self.base_shape, self.ker)
@classmethod
-def create(cls, shape, dtype, ker):
-uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
-return cls(uop, ker)
+def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:STBaseShape, ker):
+rows = shape[-2]
+cols = shape[-1]
+assert rows % base_shape.rows == 0
+assert cols % base_shape.cols == 0
+assert cols % base_shape.elements_per_thread == 0
+height = rows // base_shape.rows
+width = cols // base_shape.cols
+uop = ker.alloc(shape[:-2] + (height, width, base_shape.rows, base_shape.cols), dtype, AddrSpace.LOCAL)
+return cls(uop, rows, cols, layout, base_shape, ker)
+def swizzle(self, row, col):
+swizzled_offset = self.base_shape.swizzle(row, col, self._uop.dtype.base.scalar())
+row = swizzled_offset // self.base_shape.cols
+col = swizzled_offset % self.base_shape.cols
+return row, col
+@dataclass(frozen=True)
+class RTBaseShape(BaseShape):
+stride: int
+@property
+def num_strides(self):
+return self.elements_per_thread // self.stride
+RT_16X16 = RTBaseShape(rows=16, cols=16, stride=4)
+RT_32X32 = RTBaseShape(rows=32, cols=32, stride=4)
+RT_32X32_8 = RTBaseShape(rows=32, cols=32, stride=8)
+RT_16X32 = RTBaseShape(rows=16, cols=32, stride=8)
+RT_32X16 = RTBaseShape(rows=32, cols=16, stride=8)
+RT_32X16_4 = RTBaseShape(rows=32, cols=16, stride=4)
+RT_16X32_4 = RTBaseShape(rows=16, cols=32, stride=4)
@autowrap(UOp)
class RT(TileMathMixin):
-BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16
-BASE_TILE_NE = BASE_TILE_ROWS * BASE_TILE_COLS
-BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
-def __init__(self, uop, ker):
-self._uop, self.ker = uop, ker
-def ruop(self, uop):
-return RT(uop, self.ker)
+def __init__(self, uop:UOp, layout:TileLayout, base_shape:RTBaseShape, ker):
+self._uop, self.layout, self.base_shape, self.ker = uop, layout, base_shape, ker
+def ruop(self, uop:UOp):
+return RT(uop, self.layout, self.base_shape, self.ker)
@classmethod
-def create(cls, shape, dtype, ker):
+def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:RTBaseShape, ker):
assert len(shape) == 2
-assert shape[0] % RT.BASE_TILE_ROWS == 0
-assert shape[1] % RT.BASE_TILE_COLS == 0
+assert shape[0] % base_shape.rows == 0
+assert shape[1] % base_shape.cols == 0
-height = shape[0] // RT.BASE_TILE_ROWS
-width = shape[1] // RT.BASE_TILE_COLS
+height = shape[0] // base_shape.rows
+width = shape[1] // base_shape.cols
-uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
-return cls(uop, ker)
+uop = ker.alloc((height, width, base_shape.elements_per_thread), dtype, AddrSpace.REG)
+return cls(uop, layout, base_shape, ker)
@autowrap(UOp)
class RV(TileMathMixin):
-def __init__(self, uop, layout, ker):
+def __init__(self, uop:UOp, layout:VecLayout, ker):
self._uop, self.layout, self.ker = uop, layout, ker
-def ruop(self, uop):
+def ruop(self, uop:UOp):
return RV(uop, self.layout, self.ker)
@classmethod
-def create(cls, length, dtype, layout, ker):
-tiles = length // RT.BASE_TILE_ROWS
+def create(cls, length, dtype:DType, layout:VecLayout, base_shape:RTBaseShape, ker):
+tiles = length // base_shape.rows
match layout:
case "naive":
inner_dim = 1
outer_dim = (tiles + 1) // 2
case "ortho":
case VecLayout.ORTHO:
inner_dim = 1
outer_dim = tiles
case _: raise NotImplementedError(f"rv layout {layout} not implemented")
-uop = ker.alloc((outer_dim, inner_dim, 2), dtype, AddrSpace.REG)
+uop = ker.alloc((outer_dim, inner_dim), dtype, AddrSpace.REG)
return RV(uop, layout, ker)
ALL_TILES = UOp | GL | ST | RT | RV
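The STBaseShape swizzles XOR a few low bits of the byte offset with bits derived from the row, which spreads a wavefront's accesses across shared-memory banks without moving data between rows. A quick sanity check of the 16x16 bf16 path, with the formula copied from st_16x16_swizzled_swizzle:

```python
ROWS, COLS, ITEMSIZE = 16, 16, 2   # a bf16 16x16 tile
ROW_BYTES = COLS * ITEMSIZE        # 32 bytes per row

def swizzle(offset: int) -> int:   # st_16x16_swizzled_swizzle, itemsize == 2 path
    return offset ^ (((offset % 512) >> 7) << 3)

out = [swizzle(o) for o in range(ROWS * ROW_BYTES)]
assert sorted(out) == list(range(ROWS * ROW_BYTES))   # a permutation, no collisions
assert all(o // ROW_BYTES == s // ROW_BYTES for o, s in enumerate(out))  # rows never change
```

Only bits 3 and 4 of the offset are flipped, and they are flipped by a function of the row bits, so each row keeps its 32-byte footprint while successive rows see rotated bank positions.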

View File

@@ -0,0 +1,156 @@
from tinygrad.helpers import colored
WARP_THREADS = 64
BASE_TILE_ROWS = 16
BASE_TILE_COLS = 16
BASE_TILE_NEPT = (BASE_TILE_ROWS * BASE_TILE_COLS) // WARP_THREADS
DTYPE_SIZE = 2
INST = "ds_read_b64"
def row_col(threadIdx_x):
local_warpid = threadIdx_x // WARP_THREADS
warp_laneid = threadIdx_x % WARP_THREADS
ret = []
for inner in range(BASE_TILE_NEPT):
if BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 16:
row = warp_laneid % 16
col = 4 * (warp_laneid // 16)
elif BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 32:
row = warp_laneid % 16
col = 8 * (warp_laneid // 16)
row_offset = 0
col_offset = inner
# swizzle then find row and col
offset = (row + row_offset) * BASE_TILE_COLS + (col + col_offset)
offset *= DTYPE_SIZE
if BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 16:
swizzle = ((offset % 512) >> 7) << 3
offset = offset ^ swizzle
elif BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 32:
swizzle = ((offset % 1024) >> 9) << 5
offset = offset ^ swizzle
offset //= DTYPE_SIZE
row = offset // BASE_TILE_COLS
col = offset % BASE_TILE_COLS
ret.append((row, col))
return ret
# ===
def shm_phase(inst, threadIdx_x):
match inst:
case "ds_read_b128":
match threadIdx_x:
case 0 | 1 | 2 | 3 | 12 | 13 | 14 | 15 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27: return 0
case 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 16 | 17 | 18 | 19 | 28 | 29 | 30 | 31: return 1
case 32 | 33 | 34 | 35 | 44 | 45 | 46 | 47 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59: return 2
case 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 48 | 49 | 50 | 51 | 60 | 61 | 62 | 63: return 3
case "ds_read_b64":
if threadIdx_x < 32: return 0
else: return 1
case "ds_write_b64":
if threadIdx_x < 16: return 0
elif threadIdx_x < 32: return 1
elif threadIdx_x < 48: return 2
else: return 3
def shm_bank(inst, row, col):
bank = row * (BASE_TILE_COLS // 2) + (col // 2)
match inst:
case "ds_read_b128": bank = bank % 64
case "ds_read_b64": bank = bank % 64
case "ds_write_b64": bank = bank % 32
return bank
def map_range(value, from_min, from_max, to_min, to_max):
ratio = (value - from_min) / (from_max - from_min)
return to_min + ratio * (to_max - to_min)
def shm_bank_gradient(inst, bank):
# rgb color for each bank
# for 16 bit elements, two elements per bank row wise
# gradient from blue to red
amount = map_range(bank, 0, (64 if inst != "ds_write_b64" else 32) - 1, 0, 120)
amount = int(amount)
return (amount, amount // 2, 120 - amount)
def color_code(phase):
match phase:
case 0: return "red"
case 1: return "green"
case 2: return "blue"
case 3: return "yellow"
def rgb_bg(text, color):
return f"\033[48;2;{color[0]};{color[1]};{color[2]}m{text}\033[0m"
def visualize_threads(inst=INST):
for threadIdx_x in range(WARP_THREADS):
row, col = zip(*row_col(threadIdx_x))
print(f"Thread {threadIdx_x:2}: ", end="")
for r, c in zip(row, col):
phase = shm_phase(inst, threadIdx_x)
color = color_code(phase)
print(f"{color}({r:3},{c:3})\033[0m ", end="")
print()
unique_pairs = set()
for threadIdx_x in range(WARP_THREADS):
rc_list = row_col(threadIdx_x)
for rc in rc_list:
unique_pairs.add(rc)
assert len(unique_pairs) == 64 * BASE_TILE_NEPT, f"Expected {64 * BASE_TILE_NEPT} unique pairs, got {len(unique_pairs)}"
def visualize_tile(inst=INST):
tile = [[-1 for _ in range(BASE_TILE_COLS)] for _ in range(BASE_TILE_ROWS)]
for threadIdx_x in range(WARP_THREADS):
rc_list = row_col(threadIdx_x)
for r, c in rc_list:
try:
tile[r][c] = threadIdx_x
except IndexError:
pass
bank_conflicts = {}
print("\nTile layout (each number indicates the thread holding that position):")
for r in range(BASE_TILE_ROWS):
for c in range(BASE_TILE_COLS):
phase = shm_phase(inst, tile[r][c])
bank = shm_bank(inst, r, c)
color = color_code(phase)
bank_color = shm_bank_gradient(inst, bank)
if (bank, phase) not in bank_conflicts:
bank_conflicts[(bank, phase)] = []
bank_conflicts[(bank, phase)].append((r, c, tile[r][c]))
if tile[r][c] == -1:
bank_color = (0, 0, 0)
text = colored(f"{tile[r][c]:2}", color)
text = rgb_bg(text, bank_color)
print(f"{text:2}", end=" ")
print()
for (bank, phase), positions in bank_conflicts.items():
if len(positions) > 1:
unique_threads = set(pos[2] for pos in positions)
if len(unique_threads) > 1:
print(f"{len(unique_threads)} way bank conflict: bank {bank}")
if __name__ == "__main__":
visualize_tile()
# visualize_threads()
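Since the other files import WARP_THREADS from extra.thunder.tiny.tk, this new file is presumably the package's __init__.py; if so, the visualizers can also be driven directly with any of the instruction names handled by shm_phase:

```python
from extra.thunder.tiny.tk import visualize_tile, visualize_threads

visualize_tile("ds_read_b64")      # tile ownership colored by phase, banks as background
visualize_threads("ds_write_b64")  # per-thread (row, col) pairs colored by phase
```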