mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tk mi350 (#13288)
This commit is contained in:
@@ -7,22 +7,21 @@ from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.helpers import getenv, prod
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, ST, RT, RV
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout
|
||||
|
||||
class Group:
|
||||
def __init__(self, warps:int, ker):
|
||||
self.warps = warps
|
||||
self.group_threads = warps * WARP_THREADS
|
||||
self.threadIdx_x = ker.threadIdx_x
|
||||
self.ker = ker
|
||||
|
||||
# helpers
|
||||
@property
|
||||
def laneid(self): return self.threadIdx_x % self.group_threads
|
||||
def laneid(self): return self.ker.threadIdx_x % self.group_threads
|
||||
@property
|
||||
def warpid(self): return self.laneid // WARP_THREADS
|
||||
@property
|
||||
def groupid(self): return self.threadIdx_x // self.group_threads
|
||||
def groupid(self): return self.ker.threadIdx_x // self.group_threads
|
||||
|
||||
# ops that only work on a single warp
|
||||
|
||||
@@ -40,6 +39,7 @@ class Group:
|
||||
return reg.after(reg_store).reshape(reg.shape)
|
||||
|
||||
# Shorthand for `self.clear(reg, 0)`; returns whatever `clear` returns.
def zero(self, reg:ALL_TILES): return self.clear(reg, 0)
|
||||
# Shorthand for `self.clear(reg, 1)`; returns whatever `clear` returns.
def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
|
||||
# Shorthand for `self.clear(reg, -math.inf)`; returns whatever `clear` returns.
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
|
||||
|
||||
copy_rid = 300
|
||||
@@ -51,7 +51,22 @@ class Group:
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
|
||||
Group.copy_rid += len(dst.shape)
|
||||
|
||||
dst_store = dst[*rngs_for_shape].store(src[*rngs_for_shape].cast(dst.dtype.base)).end(*rngs_for_shape)
|
||||
src_load = src[*rngs_for_shape]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*rngs_for_shape].store(src_load).end(*rngs_for_shape)
|
||||
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
|
||||
def transpose(self, dst:UOp|RT, src:UOp|RT):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
for inner in self.ker.range(src.shape[-1], track=False):
|
||||
dst_store = dst[width, height, inner].store(src[height, width, inner]).end(height, width, inner)
|
||||
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
@@ -60,20 +75,27 @@ class Group:
|
||||
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
|
||||
assert self.warps == 1
|
||||
|
||||
a_base_shape = cast(RT, a).base_shape
|
||||
if a_base_shape.cols == 16:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
|
||||
else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
|
||||
wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
|
||||
b_in1 = UOp.vectorize(*([b[inner, width, i] for i in range(2)] + [b[inner, width, 4+i] for i in range(2)]))
|
||||
c_out1 = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
b_in2 = UOp.vectorize(*([b[inner, width, 2+i] for i in range(2)] + [b[inner, width, 6+i] for i in range(2)]))
|
||||
c_out2 = UOp.vectorize(*[c[height, width, 4+i] for i in range(4)])
|
||||
|
||||
out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg)
|
||||
out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out1.gep(i)) for i in range(4)] + [c[height, width, 4+i].store(out2.gep(i)) for i in range(4)]
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
c_store = UOp.group(*c_i).end(height, width, inner)
|
||||
|
||||
self.ker.push_store(c_store, c)
|
||||
@@ -83,20 +105,87 @@ class Group:
|
||||
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
|
||||
assert self.warps == 1
|
||||
|
||||
a_base_shape = cast(RT, a).base_shape
|
||||
if a_base_shape.cols == 16:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
|
||||
else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
|
||||
wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
|
||||
b_in1 = UOp.vectorize(*([b[width, inner, i] for i in range(2)] + [b[width, inner, 4+i] for i in range(2)]))
|
||||
c_out1 = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
b_in2 = UOp.vectorize(*([b[width, inner, 2+i] for i in range(2)] + [b[width, inner, 6+i] for i in range(2)]))
|
||||
c_out2 = UOp.vectorize(*[c[height, width, 4+i] for i in range(4)])
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
c_store = UOp.group(*c_i).end(height, width, inner)
|
||||
|
||||
out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg)
|
||||
out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out1.gep(i)) for i in range(4)] + [c[height, width, 4+i].store(out2.gep(i)) for i in range(4)]
|
||||
self.ker.push_store(c_store, c)
|
||||
return c.after(c_store).reshape(c.shape)
|
||||
|
||||
def mma_AtB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
  """Emit WMMA ops that accumulate into `c`, indexing `a` as `[inner, height, i]` and `b` as `[inner, width, i]`.

  The reduction range runs over `a.shape[-3]`. For every (height, width) pair, four values of `c`
  are fed to the WMMA as the accumulator input and the four outputs are stored back to the same slots.

  Args:
    c: accumulator tile, read and written.
    a: left operand; its `base_shape.cols` (16 or 32) selects the WMMA variant and operand vector width.
    b: right operand.

  Returns:
    `c` ordered after the emitted stores, reshaped to `c.shape`.

  Raises:
    NotImplementedError: if `a.base_shape.cols` is neither 16 nor 32.
  """
  c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
  assert self.warps == 1

  # Choose the WMMA descriptor and the number of a/b elements vectorized per op once, up front.
  # The loop body previously repeated this branch, with an `else: raise` that could never be reached.
  a_base_shape = cast(RT, a).base_shape
  if a_base_shape.cols == 16:
    operand_width = 4
    wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
  elif a_base_shape.cols == 32:
    operand_width = 8
    wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
  else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")

  # NOTE(review): loop nesting below was reconstructed from a diff view with stripped indentation — confirm against the repository.
  for height in self.ker.range(c.shape[-3], track=False):
    for width in self.ker.range(c.shape[-2], track=False):
      for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
        a_in = UOp.vectorize(*[a[inner, height, i] for i in range(operand_width)])
        b_in = UOp.vectorize(*[b[inner, width, i] for i in range(operand_width)])
        d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])

        out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
        c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
        # close all three ranges on the grouped stores
        c_store = UOp.group(*c_i).end(height, width, inner)

  self.ker.push_store(c_store, c)
  return c.after(c_store).reshape(c.shape)
|
||||
|
||||
def mma_AtBt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT):
|
||||
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
|
||||
assert self.warps == 1
|
||||
|
||||
a_base_shape = cast(RT, a).base_shape
|
||||
if a_base_shape.cols == 16:
|
||||
wmma_arg = ('WMMA_16_16_16___bf16_float', (16, 16, 16), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())
|
||||
elif a_base_shape.cols == 32:
|
||||
wmma_arg = ('WMMA_16_16_32___bf16_float', (16, 16, 32), dtypes.bfloat16, dtypes.float, 'AMD', 64, (((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2), (9, 2)), ((4, 2), (3, 2))), ())
|
||||
else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
|
||||
|
||||
for height in self.ker.range(c.shape[-3], track=False):
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
c_store = UOp.group(*c_i).end(height, width, inner)
|
||||
|
||||
self.ker.push_store(c_store, c)
|
||||
@@ -120,171 +209,213 @@ class Group:
|
||||
self.ker.push_store(a_store, a)
|
||||
return a.after(a_store).reshape(a.shape)
|
||||
|
||||
def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp]):
|
||||
def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp], init_value:float=0.0):
|
||||
vec, src = cast(UOp, vec), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
|
||||
red_local = self.ker.alloc((self.group_threads, 2), src.dtype.base, AddrSpace.LOCAL)
|
||||
red_reg = self.ker.alloc((2,), src.dtype.base, AddrSpace.REG)
|
||||
red_local = self.ker.alloc((self.group_threads,), src.dtype.base, AddrSpace.LOCAL)
|
||||
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
|
||||
reg_store = red_reg.flatten()[i].store(0.).end(i)
|
||||
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
for outer in self.ker.range(2, track=False):
|
||||
for width in self.ker.range(src.shape[-2], axis_type=AxisType.REDUCE, track=False):
|
||||
for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
|
||||
elem_index = inner + 2 * (inner // 2) + outer * 2
|
||||
reg_store = red_reg[outer].store(op(red_reg[outer], src[height, width, elem_index])).end(inner, width, outer)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# store to shared memory
|
||||
for outer in self.ker.range(2, track=False):
|
||||
red_local_store = red_local[self.laneid, outer].store(red_reg[outer]).end(outer)
|
||||
red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
|
||||
|
||||
# reduce from shared memory
|
||||
for outer in self.ker.range(2, track=False):
|
||||
for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
|
||||
offset = (self.laneid // 4) * 4 + ((self.laneid + inner + 1) % 4)
|
||||
reg_store = red_reg[outer].store(op(red_reg[outer], red_local[offset, outer])).end(inner, outer)
|
||||
for width in self.ker.range(src.shape[-2], axis_type=AxisType.REDUCE, track=False):
|
||||
for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
|
||||
reg_store = red_reg[0].store(op(red_reg[0], src[height, width, inner])).end(width, inner)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# store to shared memory
|
||||
red_local_store = red_local[self.laneid].store(red_reg[0])
|
||||
red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
|
||||
|
||||
# reduce from shared memory
|
||||
for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
|
||||
offset = (self.laneid + (1 + inner) * 16) % self.group_threads
|
||||
reg_store = red_reg[0].store(op(red_reg[0], red_local[offset])).end(inner)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# reduce with vec
|
||||
for outer in self.ker.range(2, track=False):
|
||||
vec_store = vec[height, 0, outer].store(op(vec[height, 0, outer], red_reg[outer])).end(outer, height)
|
||||
vec_store = vec[height, 0].store(op(vec[height, 0], red_reg[0])).end(height)
|
||||
|
||||
self.ker.push_store(vec_store, vec)
|
||||
return vec.after(vec_store).reshape(vec.shape)
|
||||
|
||||
def col_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp], init_value:float=0.0):
|
||||
vec, src = cast(UOp, vec), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
|
||||
red_local = self.ker.alloc((self.group_threads,), src.dtype.base, AddrSpace.LOCAL)
|
||||
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
||||
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
|
||||
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
for height in self.ker.range(src.shape[-3], axis_type=AxisType.REDUCE, track=False):
|
||||
for inner in self.ker.range(4, axis_type=AxisType.REDUCE, track=False):
|
||||
reg_store = red_reg[0].store(op(red_reg[0], src[height, width, inner])).end(height, inner)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# store to shared memory
|
||||
red_local_store = red_local[self.laneid].store(red_reg[0])
|
||||
red_local = red_local.after(red_local_store.barrier()).reshape(red_local.shape)
|
||||
|
||||
# reduce from shared memory
|
||||
for inner in self.ker.range(3, axis_type=AxisType.REDUCE, track=False):
|
||||
offset = (self.laneid + (1 + inner) * 16) % self.group_threads
|
||||
reg_store = red_reg[0].store(op(red_reg[0], red_local[offset])).end(inner)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
# reduce with vec
|
||||
vec_store = vec[width, 0].store(op(vec[width, 0], red_reg[0])).end(width)
|
||||
|
||||
self.ker.push_store(vec_store, vec)
|
||||
return vec.after(vec_store).reshape(vec.shape)
|
||||
|
||||
# ops that can work across multiple warps
|
||||
|
||||
LOAD_INNER = 4
|
||||
def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
|
||||
def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
|
||||
if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
|
||||
srcf = src.flatten(-2)
|
||||
|
||||
if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4)
|
||||
else: local_warpid = self.warpid
|
||||
warp_laneid = self.threadIdx_x % WARP_THREADS
|
||||
laneid = self.ker.laneid
|
||||
rt, st = cast(RT, dst), cast(ST, src)
|
||||
elements_per_thread = rt.base_shape.elements_per_thread
|
||||
|
||||
for height in self.ker.range(dst.shape[-3], track=False):
|
||||
for width in self.ker.range(dst.shape[-2], track=False):
|
||||
for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
|
||||
base_row = (local_warpid * dst.shape[-3] + height) * RT.BASE_TILE_ROWS
|
||||
base_col = width * RT.BASE_TILE_COLS
|
||||
|
||||
if not transpose:
|
||||
row = base_row + (warp_laneid // 4)
|
||||
col = base_col + 2 * (warp_laneid % 4)
|
||||
|
||||
row_offset = ((inner % 4) // 2) * 8
|
||||
col_offset = (inner % 2) + (inner // 4) * 8
|
||||
for inner in self.ker.range(elements_per_thread, track=False):
|
||||
if rt.layout != st.layout:
|
||||
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
|
||||
col = laneid % rt.base_shape.cols
|
||||
else:
|
||||
row = base_row + 2 * (warp_laneid % 4)
|
||||
col = base_col + (warp_laneid // 4)
|
||||
row = laneid % rt.base_shape.rows
|
||||
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
|
||||
|
||||
row_offset = (inner % 2) + (inner // 4) * 8
|
||||
col_offset = ((inner % 4) // 2) * 8
|
||||
srow, scol = cast(ST, src).swizzle(row, col)
|
||||
|
||||
src_i_last = (row + row_offset) * src.shape[-1] + col + col_offset
|
||||
|
||||
dst_store = dst[*dst_idxs, height, width, inner].store(srcf[*idxs[:-2], src_i_last])
|
||||
src_load = src[*idxs[:-2], height, width, srow, scol]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, inner].store(src_load)
|
||||
dst_store = dst_store.end(height, width, inner)
|
||||
elif dst_dtype.addrspace == AddrSpace.LOCAL and src_dtype.addrspace == AddrSpace.GLOBAL:
|
||||
dstf = dst.flatten(-2)
|
||||
|
||||
srcf = src.flatten()
|
||||
row_stride = prod(src.shape[axis+1:])
|
||||
|
||||
idxs = tuple(idx * dst.shape[-2] if i == axis else idx for i, idx in enumerate(idxs))
|
||||
idxs = tuple(idx * dst.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
st = cast(ST, dst)
|
||||
idxs = tuple(idx * st.rows if i == axis else idx for i, idx in enumerate(idxs))
|
||||
idxs = tuple(idx * st.cols if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
|
||||
|
||||
memcpy_per_row = dst.shape[-1] // Group.LOAD_INNER
|
||||
total_calls = prod(dst.shape[-2:]) // (self.group_threads * Group.LOAD_INNER)
|
||||
for height in self.ker.range(dst.shape[-4], track=False):
|
||||
for width in self.ker.range(dst.shape[-3], track=False):
|
||||
elements_per_thread = st.base_shape.elements_per_thread
|
||||
memcpy_per_row = st.base_shape.cols // elements_per_thread
|
||||
total_calls = st.base_shape.num_elements // (self.group_threads * elements_per_thread)
|
||||
|
||||
for outer in self.ker.range(total_calls, track=False):
|
||||
for inner in self.ker.range(Group.LOAD_INNER, track=False):
|
||||
load_idx = outer * self.group_threads + self.laneid
|
||||
row = load_idx // memcpy_per_row
|
||||
col = (load_idx * Group.LOAD_INNER) % dst.shape[-1]
|
||||
for outer in self.ker.range(total_calls, track=False):
|
||||
for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
|
||||
load_idx = outer * self.group_threads + self.laneid
|
||||
row = load_idx // memcpy_per_row
|
||||
col = (load_idx * elements_per_thread) % st.base_shape.cols + inner
|
||||
|
||||
dst_i = row * dst.shape[-1] + col + inner
|
||||
src_i += row * row_stride + col + inner
|
||||
srow, scol = cast(ST, dst).swizzle(row, col)
|
||||
|
||||
dst_store = dstf[*dst_idxs, dst_i].store(srcf[src_i]).end(outer, inner)
|
||||
src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
|
||||
src_i += row * row_stride + col
|
||||
|
||||
src_load = srcf[src_i]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
|
||||
dst_store = dst_store.end(height, width, outer, inner).barrier()
|
||||
elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace ==AddrSpace.GLOBAL:
|
||||
srcf = src.flatten()
|
||||
row_stride = prod(src.shape[axis+1:])
|
||||
|
||||
laneid = self.ker.laneid
|
||||
rt = cast(RT, dst)
|
||||
elements_per_thread = rt.base_shape.elements_per_thread
|
||||
|
||||
idxs = tuple(idx * dst.shape[-3] * rt.base_shape.rows if i == axis else idx for i, idx in enumerate(idxs))
|
||||
idxs = tuple(idx * dst.shape[-2] * rt.base_shape.cols if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
|
||||
|
||||
for height in self.ker.range(dst.shape[-3], track=False):
|
||||
for width in self.ker.range(dst.shape[-2], track=False):
|
||||
for inner in self.ker.range(elements_per_thread, track=False):
|
||||
base_row = height * rt.base_shape.rows
|
||||
base_col = width * rt.base_shape.cols
|
||||
|
||||
if rt.layout == TileLayout.COL:
|
||||
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
|
||||
col = laneid % rt.base_shape.cols
|
||||
else:
|
||||
row = laneid % rt.base_shape.rows
|
||||
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
|
||||
|
||||
srow, scol = base_row + row, base_col + col
|
||||
|
||||
src_i += srow * row_stride + scol
|
||||
|
||||
src_load = srcf[src_i]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, inner].store(src_load).end(height, width, inner)
|
||||
else:
|
||||
raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
|
||||
|
||||
return dst.after(dst_store.barrier()).reshape(dst.shape)
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
|
||||
STORE_INNER = 4
|
||||
def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
|
||||
def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis:int=0):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
|
||||
dstf = dst.flatten(-2)
|
||||
|
||||
if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4)
|
||||
else: local_warpid = self.warpid
|
||||
warp_laneid = self.threadIdx_x % WARP_THREADS
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
|
||||
base_row = (local_warpid * src.shape[-3] + height) * RT.BASE_TILE_ROWS
|
||||
base_col = width * RT.BASE_TILE_COLS
|
||||
|
||||
if not transpose:
|
||||
row = base_row + (warp_laneid // 4)
|
||||
col = base_col + 2 * (warp_laneid % 4)
|
||||
|
||||
row_offset = ((inner % 4) // 2) * 8
|
||||
col_offset = (inner % 2) + (inner // 4) * 8
|
||||
else:
|
||||
row = base_row + 2 * (warp_laneid % 4)
|
||||
col = base_col + (warp_laneid // 4)
|
||||
|
||||
row_offset = (inner % 2) + (inner // 4) * 8
|
||||
col_offset = ((inner % 4) // 2) * 8
|
||||
|
||||
dst_i_last = (row + row_offset) * dst.shape[-1] + col + col_offset
|
||||
|
||||
dst_store = dstf[*idxs[:-2], dst_i_last].store(src[*src_idxs, height, width, inner])
|
||||
dst_store = dst_store.end(height, width, inner)
|
||||
elif src_dtype.addrspace == AddrSpace.LOCAL and dst_dtype.addrspace == AddrSpace.GLOBAL:
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL:
|
||||
dstf = dst.flatten()
|
||||
row_stride = prod(dst.shape[axis+1:])
|
||||
|
||||
idxs = tuple(idx * src.shape[-2] if i == axis else idx for i, idx in enumerate(idxs))
|
||||
idxs = tuple(idx * src.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
laneid = self.ker.laneid
|
||||
rt = cast(RT, src)
|
||||
elements_per_thread = rt.base_shape.elements_per_thread
|
||||
|
||||
idxs = tuple(idx * src.shape[-3] * rt.base_shape.rows if i == axis else idx for i, idx in enumerate(idxs))
|
||||
idxs = tuple(idx * src.shape[-2] * rt.base_shape.cols if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3]
|
||||
|
||||
srcf = src.flatten(-2)
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
for inner in self.ker.range(elements_per_thread, track=False):
|
||||
base_row = height * rt.base_shape.rows
|
||||
base_col = width * rt.base_shape.cols
|
||||
|
||||
memcpy_per_row = src.shape[-1] // Group.STORE_INNER
|
||||
total_calls = prod(src.shape[-2:]) // (self.group_threads * Group.STORE_INNER)
|
||||
if rt.layout == TileLayout.COL:
|
||||
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
|
||||
col = laneid % rt.base_shape.cols
|
||||
else:
|
||||
row = laneid % rt.base_shape.rows
|
||||
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
|
||||
|
||||
for outer in self.ker.range(total_calls, track=False):
|
||||
for inner in self.ker.range(Group.STORE_INNER, track=False):
|
||||
load_idx = outer * self.group_threads + self.laneid
|
||||
row = load_idx // memcpy_per_row
|
||||
col = (load_idx * Group.STORE_INNER) % src.shape[-1]
|
||||
srow, scol = base_row + row, base_col + col
|
||||
|
||||
src_i = row * src.shape[-1] + col + inner
|
||||
dst_i += row * row_stride + col + inner
|
||||
dst_i += srow * row_stride + scol
|
||||
|
||||
dst_store = dstf[dst_i].store(srcf[*src_idxs, src_i]).end(outer, inner)
|
||||
src_load = src[*src_idxs, height, width, inner]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dstf[dst_i].store(src_load).end(height, width, inner)
|
||||
else:
|
||||
raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
|
||||
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store.barrier()).reshape(dst.shape)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
|
||||
@@ -2,7 +2,7 @@ from contextlib import AbstractContextManager
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, AxisType, AddrSpace
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.group import Group
|
||||
from extra.thunder.tiny.tk.tiles import GL, ST, RT, RV
|
||||
from extra.thunder.tiny.tk.tiles import GL, ST_16X16, ST_16X16_SWIZZLED, ST, RT_16X16, RT, RV, TileLayout, VecLayout
|
||||
|
||||
class _tk_range:
|
||||
user_rid = 0
|
||||
@@ -35,6 +35,8 @@ class Kernel(AbstractContextManager):
|
||||
|
||||
@property
|
||||
def warpid(self): return self.threadIdx_x // WARP_THREADS
|
||||
@property
|
||||
def laneid(self): return self.threadIdx_x % WARP_THREADS
|
||||
|
||||
def __enter__(self): return self
|
||||
def __exit__(self, exc_type, exc_value, traceback): pass
|
||||
@@ -72,9 +74,9 @@ class Kernel(AbstractContextManager):
|
||||
return uop
|
||||
|
||||
def gl(self, shape, dtype): return GL.create(shape, dtype, self)
|
||||
def st(self, shape, dtype): return ST.create(shape, dtype, self)
|
||||
def rt(self, shape, dtype): return RT.create(shape, dtype, self)
|
||||
def rv(self, length, dtype, layout="naive"): return RV.create(length, dtype, layout, self)
|
||||
def st(self, shape, dtype, layout=TileLayout.ROW, base_shape=ST_16X16): return ST.create(shape, dtype, layout, base_shape, self)
|
||||
def rt(self, shape, dtype, layout=TileLayout.ROW, base_shape=RT_16X16): return RT.create(shape, dtype, layout, base_shape, self)
|
||||
def rv(self, length, dtype, layout=VecLayout.ORTHO, rt_base_shape=RT_16X16): return RV.create(length, dtype, layout, rt_base_shape, self)
|
||||
|
||||
def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop))
|
||||
|
||||
@@ -92,4 +94,4 @@ class Kernel(AbstractContextManager):
|
||||
def endrange(self):
|
||||
last_store = self.store_stack.pop()
|
||||
last_range = self.range_stack.pop()
|
||||
return last_store[1].after(last_store[0].barrier().end(last_range._rng)).reshape(last_store[1].shape)
|
||||
return last_store[1].after(last_store[0].end(last_range._rng)).reshape(last_store[1].shape)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from enum import Enum, auto
|
||||
import functools
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from typing import Callable
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import AddrSpace, DType
|
||||
from tinygrad.mixin import MathMixin
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
@@ -66,7 +69,10 @@ class TileMathMixin(MathMixin):
|
||||
elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
|
||||
elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
|
||||
else:
|
||||
if isinstance(self, RT) and isinstance(src[0], RV): uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0, (idx[2]%4)//2])))
|
||||
if isinstance(self, RT) and isinstance(src[0], RV):
|
||||
match self.layout:
|
||||
case TileLayout.ROW: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0])))
|
||||
case TileLayout.COL: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[1], 0])))
|
||||
else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
|
||||
else: raise NotImplementedError
|
||||
return self.ruop(uop)
|
||||
@@ -80,76 +86,186 @@ class TileMathMixin(MathMixin):
|
||||
|
||||
@autowrap(UOp)
|
||||
class GL:
|
||||
def __init__(self, uop, ker):
|
||||
def __init__(self, uop:UOp, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
|
||||
def ruop(self, uop):
|
||||
def ruop(self, uop:UOp):
|
||||
return GL(uop, self.ker)
|
||||
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
def create(cls, shape, dtype:DType, ker):
|
||||
uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
|
||||
return cls(uop, ker)
|
||||
|
||||
class TileLayout(Enum):
  """Layout tag carried by tiles.

  TileMathMixin matches on it when an RT is combined with an RV: for ROW the
  vector element is selected by the first tile index, for COL by the second.
  """
  ROW = auto()
  COL = auto()
|
||||
|
||||
class VecLayout(Enum):
  """Layout tag for register vectors (RV); ORTHO is the only member defined."""
  ORTHO = auto()
|
||||
|
||||
@dataclass(frozen=True)
class BaseShape:
  """Immutable (rows, cols) dimensions of a base tile."""
  rows: int
  cols: int

  @property
  def num_elements(self):
    """Total number of elements in the base tile (rows * cols)."""
    return self.rows * self.cols

  @property
  def elements_per_thread(self):
    """Elements per thread when the tile is divided over WARP_THREADS threads (floor division)."""
    return self.num_elements // WARP_THREADS
|
||||
|
||||
@dataclass(frozen=True)
class STBaseShape(BaseShape):
  """Base tile shape bundled with a byte-offset swizzle and a bytes-per-thread lookup."""
  # (byte offset inside the base tile, dtype) -> swizzled byte offset
  _swizzle: Callable[[UOp, DType], UOp]
  # dtype -> bytes per thread
  # NOTE(review): presumably the number of bytes one thread moves per access; confirm against callers
  bytes_per_thread: Callable[[DType], int]

  def swizzle(self, row, col, dtype:DType):
    """Return the swizzled element offset of (row, col) inside the base tile.

    The row-major element offset is scaled to bytes by `dtype.itemsize`, passed
    through `_swizzle`, then floor-divided back to element units.
    """
    byte_offset = (row * self.cols + col) * dtype.itemsize
    return self._swizzle(byte_offset, dtype) // dtype.itemsize
|
||||
|
||||
def st_16x16_swizzle(offset:UOp, _):
  """Identity swizzle: the byte offset is returned unchanged for every dtype."""
  return offset

def st_16x16_bpt(dtype:DType):
  """Return the bytes-per-thread value for `dtype`: 16 for 2- and 4-byte element types.

  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16

# 16x16 base tile with no swizzling
ST_16X16 = STBaseShape(16, 16, st_16x16_swizzle, st_16x16_bpt)
|
||||
|
||||
def st_16x16_swizzled_swizzle(offset:UOp, dtype:DType):
  """Swizzle a byte offset inside a 16x16 base tile.

  4-byte element types are left untouched. For 2-byte element types,
  `(offset % 512) >> 7` is shifted left by 3 and XORed into the offset.
  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize == 4: return offset
  if dtype.itemsize != 2: raise NotImplementedError
  xor_mask = ((offset % 512) >> 7) << 3
  return offset ^ xor_mask

def st_16x16_swizzled_bpt(dtype:DType):
  """Return the bytes-per-thread value for `dtype`: 4 for 2-byte types, 16 for 4-byte types.

  Raises NotImplementedError for any other element size.
  """
  bytes_by_itemsize = {2: 4, 4: 16}
  if dtype.itemsize not in bytes_by_itemsize: raise NotImplementedError
  return bytes_by_itemsize[dtype.itemsize]

# 16x16 base tile using the XOR swizzle above
ST_16X16_SWIZZLED = STBaseShape(16, 16, st_16x16_swizzled_swizzle, st_16x16_swizzled_bpt)
|
||||
|
||||
def st_32x32_swizzle(offset:UOp, dtype:DType):
  """Swizzle a byte offset inside a 32x32 base tile.

  4-byte element types are left untouched. For 2-byte element types two masks
  are XORed into the offset: `(offset % 1024) >> 9` shifted left by 5, and
  `(offset % 2048) >> 10` shifted left by 4.
  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize == 4: return offset
  if dtype.itemsize != 2: raise NotImplementedError
  mask_into_bit5 = ((offset % 1024) >> 9) << 5
  mask_into_bit4 = ((offset % 2048) >> 10) << 4
  return offset ^ mask_into_bit5 ^ mask_into_bit4

def st_32x32_bpt(dtype:DType):
  """Return the bytes-per-thread value for `dtype`: 16 for 2- and 4-byte element types.

  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16

# 32x32 base tile using the two-mask XOR swizzle above
ST_32X32 = STBaseShape(32, 32, st_32x32_swizzle, st_32x32_bpt)
|
||||
|
||||
def st_16x32_swizzle(offset:UOp, dtype:DType):
  """Swizzle a byte offset inside a 16x32 base tile.

  4-byte element types are left untouched. For 2-byte element types,
  `(offset % 1024) >> 9` is shifted left by 5 and XORed into the offset.
  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize == 4: return offset
  if dtype.itemsize != 2: raise NotImplementedError
  xor_mask = ((offset % 1024) >> 9) << 5
  return offset ^ xor_mask

def st_16x32_bpt(dtype:DType):
  """Return the bytes-per-thread value for `dtype`: 16 for 2- and 4-byte element types.

  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16

# 16x32 base tile using the XOR swizzle above
ST_16X32 = STBaseShape(16, 32, st_16x32_swizzle, st_16x32_bpt)
|
||||
|
||||
def st_32x16_swizzle(offset:UOp, dtype:DType):
  """Swizzle a byte offset inside a 32x16 base tile.

  4-byte element types are left untouched. For 2-byte element types,
  `(offset % 1024) >> 9` is shifted left by 4 and XORed into the offset.
  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize == 4: return offset
  if dtype.itemsize != 2: raise NotImplementedError
  xor_mask = ((offset % 1024) >> 9) << 4
  return offset ^ xor_mask

def st_32x16_bpt(dtype:DType):
  """Return the bytes-per-thread value for `dtype`: 16 for 2- and 4-byte element types.

  Raises NotImplementedError for any other element size.
  """
  if dtype.itemsize not in (2, 4): raise NotImplementedError
  return 16

# 32x16 base tile using the XOR swizzle above
ST_32X16 = STBaseShape(32, 16, st_32x16_swizzle, st_32x16_bpt)
|
||||
|
||||
@autowrap(UOp)
|
||||
class ST:
|
||||
def __init__(self, uop, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
def __init__(self, uop:UOp, rows:int, cols:int, layout:TileLayout, base_shape:STBaseShape, ker):
|
||||
self._uop, self.rows, self.cols, self.layout, self.base_shape, self.ker = uop, rows, cols, layout, base_shape, ker
|
||||
|
||||
def ruop(self, uop):
|
||||
return ST(uop, self.ker)
|
||||
def ruop(self, uop:UOp):
|
||||
return ST(uop, self.rows, self.cols, self.layout, self.base_shape, self.ker)
|
||||
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
|
||||
return cls(uop, ker)
|
||||
def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:STBaseShape, ker):
|
||||
rows = shape[-2]
|
||||
cols = shape[-1]
|
||||
assert rows % base_shape.rows == 0
|
||||
assert cols % base_shape.cols == 0
|
||||
assert cols % base_shape.elements_per_thread == 0
|
||||
|
||||
height = rows // base_shape.rows
|
||||
width = cols // base_shape.cols
|
||||
|
||||
uop = ker.alloc(shape[:-2] + (height, width, base_shape.rows, base_shape.cols), dtype, AddrSpace.LOCAL)
|
||||
return cls(uop, rows, cols, layout, base_shape, ker)
|
||||
|
||||
def swizzle(self, row, col):
|
||||
swizzled_offset = self.base_shape.swizzle(row, col, self._uop.dtype.base.scalar())
|
||||
|
||||
row = swizzled_offset // self.base_shape.cols
|
||||
col = swizzled_offset % self.base_shape.cols
|
||||
|
||||
return row, col
|
||||
|
||||
@dataclass(frozen=True)
class RTBaseShape(BaseShape):
  """Base tile shape for register tiles; adds a per-thread `stride` to (rows, cols)."""
  # NOTE(review): presumably the number of contiguous elements a thread holds per chunk; confirm against the load/store code
  stride: int

  @property
  def num_strides(self):
    """Number of `stride`-sized chunks in one thread's share of the tile (floor division)."""
    per_thread = self.elements_per_thread
    return per_thread // self.stride
|
||||
|
||||
RT_16X16 = RTBaseShape(rows=16, cols=16, stride=4)
|
||||
RT_32X32 = RTBaseShape(rows=32, cols=32, stride=4)
|
||||
RT_32X32_8 = RTBaseShape(rows=32, cols=32, stride=8)
|
||||
RT_16X32 = RTBaseShape(rows=16, cols=32, stride=8)
|
||||
RT_32X16 = RTBaseShape(rows=32, cols=16, stride=8)
|
||||
RT_32X16_4 = RTBaseShape(rows=32, cols=16, stride=4)
|
||||
RT_16X32_4 = RTBaseShape(rows=16, cols=32, stride=4)
|
||||
|
||||
@autowrap(UOp)
|
||||
class RT(TileMathMixin):
|
||||
BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16
|
||||
BASE_TILE_NE = BASE_TILE_ROWS * BASE_TILE_COLS
|
||||
BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
|
||||
def __init__(self, uop:UOp, layout:TileLayout, base_shape:RTBaseShape, ker):
|
||||
self._uop, self.layout, self.base_shape, self.ker = uop, layout, base_shape, ker
|
||||
|
||||
def __init__(self, uop, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
|
||||
def ruop(self, uop):
|
||||
return RT(uop, self.ker)
|
||||
def ruop(self, uop:UOp):
|
||||
return RT(uop, self.layout, self.base_shape, self.ker)
|
||||
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
def create(cls, shape, dtype:DType, layout:TileLayout, base_shape:RTBaseShape, ker):
|
||||
assert len(shape) == 2
|
||||
assert shape[0] % RT.BASE_TILE_ROWS == 0
|
||||
assert shape[1] % RT.BASE_TILE_COLS == 0
|
||||
assert shape[0] % base_shape.rows == 0
|
||||
assert shape[1] % base_shape.cols == 0
|
||||
|
||||
height = shape[0] // RT.BASE_TILE_ROWS
|
||||
width = shape[1] // RT.BASE_TILE_COLS
|
||||
height = shape[0] // base_shape.rows
|
||||
width = shape[1] // base_shape.cols
|
||||
|
||||
uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
|
||||
return cls(uop, ker)
|
||||
uop = ker.alloc((height, width, base_shape.elements_per_thread), dtype, AddrSpace.REG)
|
||||
return cls(uop, layout, base_shape, ker)
|
||||
|
||||
@autowrap(UOp)
|
||||
class RV(TileMathMixin):
|
||||
def __init__(self, uop, layout, ker):
|
||||
def __init__(self, uop:UOp, layout:VecLayout, ker):
|
||||
self._uop, self.layout, self.ker = uop, layout, ker
|
||||
|
||||
def ruop(self, uop):
|
||||
def ruop(self, uop:UOp):
|
||||
return RV(uop, self.layout, self.ker)
|
||||
|
||||
@classmethod
|
||||
def create(cls, length, dtype, layout, ker):
|
||||
tiles = length // RT.BASE_TILE_ROWS
|
||||
def create(cls, length, dtype:DType, layout:VecLayout, base_shape:RTBaseShape, ker):
|
||||
tiles = length // base_shape.rows
|
||||
|
||||
match layout:
|
||||
case "naive":
|
||||
inner_dim = 1
|
||||
outer_dim = (tiles + 1) // 2
|
||||
case "ortho":
|
||||
case VecLayout.ORTHO:
|
||||
inner_dim = 1
|
||||
outer_dim = tiles
|
||||
case _: raise NotImplementedError(f"rv layout {layout} not implemented")
|
||||
|
||||
uop = ker.alloc((outer_dim, inner_dim, 2), dtype, AddrSpace.REG)
|
||||
uop = ker.alloc((outer_dim, inner_dim), dtype, AddrSpace.REG)
|
||||
return RV(uop, layout, ker)
|
||||
|
||||
ALL_TILES = UOp | GL | ST | RT | RV
|
||||
|
||||
156
extra/thunder/tiny/visualize_tile.py
Normal file
156
extra/thunder/tiny/visualize_tile.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from tinygrad.helpers import colored
|
||||
|
||||
# number of threads the visualizer iterates over; also used to derive the lane id
WARP_THREADS = 64
# shape of the base tile being visualized (row_col handles 16x16 and 16x32)
BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16
# elements of the base tile assigned to each thread
BASE_TILE_NEPT = (BASE_TILE_ROWS * BASE_TILE_COLS) // WARP_THREADS
# factor converting an element offset to the offset that gets swizzled
# NOTE(review): presumably bytes per element (2 = 16-bit types); confirm
DTYPE_SIZE = 2
# instruction name used by default for phase and bank lookups
INST = "ds_read_b64"
|
||||
|
||||
def row_col(threadIdx_x):
  """Return the swizzled (row, col) base-tile positions touched by one thread.

  The thread's lane id selects a starting row and column; each of its
  BASE_TILE_NEPT elements advances one column from there. Every position is
  converted to an offset scaled by DTYPE_SIZE, XOR-swizzled, scaled back and
  split into (row, col) again.

  Returns a list of BASE_TILE_NEPT (row, col) tuples.
  Raises NotImplementedError for base tile shapes other than 16x16 and 16x32.
  """
  warp_laneid = threadIdx_x % WARP_THREADS

  # the lane's starting position does not depend on the element index, so compute it once
  if BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 16:
    base_row, base_col = warp_laneid % 16, 4 * (warp_laneid // 16)
  elif BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 32:
    base_row, base_col = warp_laneid % 16, 8 * (warp_laneid // 16)
  else:
    # previously fell through and failed later with an unbound `row`
    raise NotImplementedError(f"no thread mapping for a {BASE_TILE_ROWS}x{BASE_TILE_COLS} base tile")

  ret = []
  for inner in range(BASE_TILE_NEPT):
    # swizzle the scaled offset, then recover row and col from it
    offset = (base_row * BASE_TILE_COLS + (base_col + inner)) * DTYPE_SIZE
    if BASE_TILE_COLS == 16: offset ^= ((offset % 512) >> 7) << 3
    else: offset ^= ((offset % 1024) >> 9) << 5
    offset //= DTYPE_SIZE
    ret.append(divmod(offset, BASE_TILE_COLS))

  return ret
|
||||
|
||||
# ===
|
||||
|
||||
def shm_phase(inst, threadIdx_x):
  """Return the phase number that `threadIdx_x` falls into for instruction `inst`.

  ds_read_b128: looked up in a fixed table covering threads 0..63 (phases 0..3);
                threads outside the table yield None.
  ds_read_b64:  0 for threads below 32, otherwise 1.
  ds_write_b64: 0, 1, 2 for threads below 16, 32, 48 respectively, otherwise 3.
  Any other instruction name yields None.
  """
  if inst == "ds_read_b128":
    # thread ids of the lower 32 lanes per phase; the upper 32 lanes repeat the pattern two phases later
    even_phase = (*range(0, 4), *range(12, 16), *range(20, 28))
    odd_phase = (*range(4, 12), *range(16, 20), *range(28, 32))
    phase_of = {}
    for half in (0, 1):
      for tid in even_phase: phase_of[tid + 32 * half] = 2 * half
      for tid in odd_phase: phase_of[tid + 32 * half] = 2 * half + 1
    return phase_of.get(threadIdx_x)

  if inst == "ds_read_b64":
    return 0 if threadIdx_x < 32 else 1

  if inst == "ds_write_b64":
    for phase, upper_bound in enumerate((16, 32, 48)):
      if threadIdx_x < upper_bound: return phase
    return 3

  return None
|
||||
|
||||
def shm_bank(inst, row, col):
  """Return the bank index of tile position (row, col) for instruction `inst`.

  Two neighbouring columns share one index (`col // 2`), and each row advances
  it by BASE_TILE_COLS // 2. The result wraps modulo 64 for ds_read_b128 and
  ds_read_b64, modulo 32 for ds_write_b64, and is left unwrapped for any
  other instruction name.
  """
  linear_bank = row * (BASE_TILE_COLS // 2) + (col // 2)
  wrap = {"ds_read_b128": 64, "ds_read_b64": 64, "ds_write_b64": 32}.get(inst)
  return linear_bank if wrap is None else linear_bank % wrap
|
||||
|
||||
def map_range(value, from_min, from_max, to_min, to_max):
  """Linearly map `value` from the interval [from_min, from_max] onto [to_min, to_max].

  Values outside the source interval are extrapolated, not clamped.
  Raises ZeroDivisionError when from_min == from_max.
  """
  fraction = (value - from_min) / (from_max - from_min)
  target_span = to_max - to_min
  return to_min + fraction * target_span
|
||||
|
||||
def shm_bank_gradient(inst, bank):
  """Return an (r, g, b) tuple for `bank`, running from blue (lowest bank) to red (highest).

  The bank index is scaled onto 0..120 over 32 banks for ds_write_b64 and over
  64 banks for every other instruction; the scaled value is truncated to int.
  """
  last_bank = (32 if inst == "ds_write_b64" else 64) - 1
  amount = int(map_range(bank, 0, last_bank, 0, 120))
  return (amount, amount // 2, 120 - amount)
|
||||
|
||||
def color_code(phase):
  """Return the colour name for phases 0..3 ("red", "green", "blue", "yellow"); None otherwise."""
  names = {0: "red", 1: "green", 2: "blue", 3: "yellow"}
  return names.get(phase)
|
||||
|
||||
def rgb_bg(text, color):
  """Wrap `text` in an ANSI 24-bit background-colour escape built from color[0..2], followed by a reset."""
  red, green, blue = color[0], color[1], color[2]
  return f"\033[48;2;{red};{green};{blue}m{text}\033[0m"
|
||||
|
||||
def visualize_threads(inst=INST):
  """Print, for every thread, the (row, col) positions it touches, coloured by phase.

  After printing, asserts that the threads together cover
  WARP_THREADS * BASE_TILE_NEPT distinct positions, i.e. no position is
  claimed twice.
  """
  for threadIdx_x in range(WARP_THREADS):
    # the phase depends only on the thread, so look it up once per row
    color = color_code(shm_phase(inst, threadIdx_x))
    # colored() applies the ANSI colour; the name used to be printed as literal text
    cells = " ".join(colored(f"({r:3},{c:3})", color) for r, c in row_col(threadIdx_x))
    print(f"Thread {threadIdx_x:2}: {cells} ")

  unique_pairs = {rc for threadIdx_x in range(WARP_THREADS) for rc in row_col(threadIdx_x)}
  expected = WARP_THREADS * BASE_TILE_NEPT
  assert len(unique_pairs) == expected, f"Expected {expected} unique pairs, got {len(unique_pairs)}"
|
||||
|
||||
def visualize_tile(inst=INST):
  """Print the base tile as a grid of owning thread ids and report bank conflicts.

  Each cell shows the thread that touches it: text coloured by the thread's
  phase, background by the cell's bank. Cells no thread touches are shown as
  -1 on a black background and are left out of the conflict report. A
  conflict is reported when more than one distinct thread hits the same bank
  within the same phase.
  """
  tile = [[-1] * BASE_TILE_COLS for _ in range(BASE_TILE_ROWS)]
  for threadIdx_x in range(WARP_THREADS):
    for r, c in row_col(threadIdx_x):
      # best effort: positions outside the base tile are dropped instead of aborting the drawing
      if 0 <= r < BASE_TILE_ROWS and 0 <= c < BASE_TILE_COLS:
        tile[r][c] = threadIdx_x

  bank_conflicts = {}

  print("\nTile layout (each number indicates the thread holding that position):")
  for r in range(BASE_TILE_ROWS):
    for c in range(BASE_TILE_COLS):
      owner = tile[r][c]
      if owner == -1:
        # unowned cell: black background, and no phase or bank bookkeeping
        print(rgb_bg(f"{owner:2}", (0, 0, 0)), end=" ")
        continue

      phase = shm_phase(inst, owner)
      bank = shm_bank(inst, r, c)
      bank_conflicts.setdefault((bank, phase), []).append((r, c, owner))

      text = colored(f"{owner:2}", color_code(phase))
      print(rgb_bg(text, shm_bank_gradient(inst, bank)), end=" ")
    print()

  for (bank, _phase), positions in bank_conflicts.items():
    unique_threads = {pos[2] for pos in positions}
    if len(unique_threads) > 1:
      print(f"{len(unique_threads)} way bank conflict: bank {bank}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_tile()
|
||||
# visualize_threads()
|
||||
Reference in New Issue
Block a user