tk: rename rt tile dims to base (#13265)

This commit is contained in:
wozeparrot
2025-11-13 18:43:02 -08:00
committed by GitHub
parent 7eb0d8e744
commit 777cbec5b3
2 changed files with 13 additions and 13 deletions

View File

@@ -178,14 +178,14 @@ class Group:
for width in self.ker.range(dst.shape[-2], track=False):
for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
if not transpose:
row = (local_warpid * dst.shape[-3] + height) * RT.TILE_ROW_DIM + (warp_laneid // 4)
col = width * RT.TILE_COL_DIM + 2 * (warp_laneid % 4)
row = (local_warpid * dst.shape[-3] + height) * RT.BASE_TILE_ROWS + (warp_laneid // 4)
col = width * RT.BASE_TILE_COLS + 2 * (warp_laneid % 4)
row_offset = ((inner % 4) // 2) * 8
col_offset = (inner % 2) + (inner // 4) * 8
else:
row = (local_warpid * dst.shape[-3] + height) * RT.TILE_ROW_DIM + 2 * (warp_laneid % 4)
col = width * RT.TILE_COL_DIM + (warp_laneid // 4)
row = (local_warpid * dst.shape[-3] + height) * RT.BASE_TILE_ROWS + 2 * (warp_laneid % 4)
col = width * RT.BASE_TILE_COLS + (warp_laneid // 4)
row_offset = (inner % 2) + (inner // 4) * 8
col_offset = ((inner % 4) // 2) * 8
@@ -237,8 +237,8 @@ class Group:
for height in self.ker.range(src.shape[-3], track=False):
for width in self.ker.range(src.shape[-2], track=False):
for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
row = (local_warpid * src.shape[-3] + height) * RT.TILE_ROW_DIM + (warp_laneid // 4)
col = width * RT.TILE_COL_DIM + 2 * (warp_laneid % 4)
row = (local_warpid * src.shape[-3] + height) * RT.BASE_TILE_ROWS + (warp_laneid // 4)
col = width * RT.BASE_TILE_COLS + 2 * (warp_laneid % 4)
row_offset = ((inner % 4) // 2) * 8
col_offset = (inner % 2) + (inner // 4) * 8

View File

@@ -100,8 +100,8 @@ class ST:
@autowrap(UOp)
class RT(TileMathMixin):
TILE_ROW_DIM, TILE_COL_DIM = 16, 16
BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16
BASE_TILE_NE = BASE_TILE_ROWS * BASE_TILE_COLS
BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
def __init__(self, uop, ker):
@@ -110,11 +110,11 @@ class RT(TileMathMixin):
@classmethod
def create(cls, shape, dtype, ker):
assert len(shape) == 2
assert shape[0] % RT.TILE_ROW_DIM == 0
assert shape[1] % RT.TILE_COL_DIM == 0
assert shape[0] % RT.BASE_TILE_ROWS == 0
assert shape[1] % RT.BASE_TILE_COLS == 0
height = shape[0] // RT.TILE_ROW_DIM
width = shape[1] // RT.TILE_COL_DIM
height = shape[0] // RT.BASE_TILE_ROWS
width = shape[1] // RT.BASE_TILE_COLS
uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
return cls(uop, ker)
@@ -126,7 +126,7 @@ class RV(TileMathMixin):
@classmethod
def create(cls, length, dtype, layout, ker):
tiles = length // RT.TILE_ROW_DIM
tiles = length // RT.BASE_TILE_ROWS
match layout:
case "naive":