mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tk: rename rt tile dims to base (#13265)
This commit is contained in:
@@ -178,14 +178,14 @@ class Group:
|
||||
for width in self.ker.range(dst.shape[-2], track=False):
|
||||
for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
|
||||
if not transpose:
|
||||
row = (local_warpid * dst.shape[-3] + height) * RT.TILE_ROW_DIM + (warp_laneid // 4)
|
||||
col = width * RT.TILE_COL_DIM + 2 * (warp_laneid % 4)
|
||||
row = (local_warpid * dst.shape[-3] + height) * RT.BASE_TILE_ROWS + (warp_laneid // 4)
|
||||
col = width * RT.BASE_TILE_COLS + 2 * (warp_laneid % 4)
|
||||
|
||||
row_offset = ((inner % 4) // 2) * 8
|
||||
col_offset = (inner % 2) + (inner // 4) * 8
|
||||
else:
|
||||
row = (local_warpid * dst.shape[-3] + height) * RT.TILE_ROW_DIM + 2 * (warp_laneid % 4)
|
||||
col = width * RT.TILE_COL_DIM + (warp_laneid // 4)
|
||||
row = (local_warpid * dst.shape[-3] + height) * RT.BASE_TILE_ROWS + 2 * (warp_laneid % 4)
|
||||
col = width * RT.BASE_TILE_COLS + (warp_laneid // 4)
|
||||
|
||||
row_offset = (inner % 2) + (inner // 4) * 8
|
||||
col_offset = ((inner % 4) // 2) * 8
|
||||
@@ -237,8 +237,8 @@ class Group:
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
for inner in self.ker.range(RT.BASE_TILE_NEPT, track=False):
|
||||
row = (local_warpid * src.shape[-3] + height) * RT.TILE_ROW_DIM + (warp_laneid // 4)
|
||||
col = width * RT.TILE_COL_DIM + 2 * (warp_laneid % 4)
|
||||
row = (local_warpid * src.shape[-3] + height) * RT.BASE_TILE_ROWS + (warp_laneid // 4)
|
||||
col = width * RT.BASE_TILE_COLS + 2 * (warp_laneid % 4)
|
||||
|
||||
row_offset = ((inner % 4) // 2) * 8
|
||||
col_offset = (inner % 2) + (inner // 4) * 8
|
||||
|
||||
@@ -100,8 +100,8 @@ class ST:
|
||||
|
||||
@autowrap(UOp)
|
||||
class RT(TileMathMixin):
|
||||
TILE_ROW_DIM, TILE_COL_DIM = 16, 16
|
||||
BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
|
||||
BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16
|
||||
BASE_TILE_NE = BASE_TILE_ROWS * BASE_TILE_COLS
|
||||
BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
|
||||
|
||||
def __init__(self, uop, ker):
|
||||
@@ -110,11 +110,11 @@ class RT(TileMathMixin):
|
||||
@classmethod
|
||||
def create(cls, shape, dtype, ker):
|
||||
assert len(shape) == 2
|
||||
assert shape[0] % RT.TILE_ROW_DIM == 0
|
||||
assert shape[1] % RT.TILE_COL_DIM == 0
|
||||
assert shape[0] % RT.BASE_TILE_ROWS == 0
|
||||
assert shape[1] % RT.BASE_TILE_COLS == 0
|
||||
|
||||
height = shape[0] // RT.TILE_ROW_DIM
|
||||
width = shape[1] // RT.TILE_COL_DIM
|
||||
height = shape[0] // RT.BASE_TILE_ROWS
|
||||
width = shape[1] // RT.BASE_TILE_COLS
|
||||
|
||||
uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
|
||||
return cls(uop, ker)
|
||||
@@ -126,7 +126,7 @@ class RV(TileMathMixin):
|
||||
|
||||
@classmethod
|
||||
def create(cls, length, dtype, layout, ker):
|
||||
tiles = length // RT.TILE_ROW_DIM
|
||||
tiles = length // RT.BASE_TILE_ROWS
|
||||
|
||||
match layout:
|
||||
case "naive":
|
||||
|
||||
Reference in New Issue
Block a user