Files
tinygrad/extra/thunder/tiny/visualize_tile.py
2025-11-25 15:49:44 -08:00

157 lines
4.6 KiB
Python

from tinygrad.helpers import colored
# Geometry and defaults shared by every helper in this script.
WARP_THREADS = 64  # threads per warp; row_col takes lane ids modulo this
BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16
# number of tile elements each thread holds
BASE_TILE_NEPT = BASE_TILE_ROWS * BASE_TILE_COLS // WARP_THREADS
DTYPE_SIZE = 2  # bytes per tile element
INST = "ds_read_b64"  # instruction modelled by default in the visualizers
def row_col(threadIdx_x):
  """Return the swizzled (row, col) tile positions owned by thread ``threadIdx_x``.

  The lane (``threadIdx_x % WARP_THREADS``) selects a base (row, col). From that base the
  lane covers BASE_TILE_NEPT consecutive columns. Each element's byte offset in the row-major
  BASE_TILE_ROWS x BASE_TILE_COLS tile (DTYPE_SIZE bytes per element) is XOR-swizzled and then
  converted back to (row, col).

  Returns:
    list of BASE_TILE_NEPT (row, col) tuples.

  Raises:
    ValueError: if the base tile shape is not 16x16 or 16x32, the only layouts defined here.
  """
  lane = threadIdx_x % WARP_THREADS
  # Per-shape layout: the lane's base position and the XOR swizzle applied to byte offsets.
  if BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 16:
    base_row, base_col = lane % 16, 4 * (lane // 16)
    swizzle_mod, swizzle_shift, swizzle_pos = 512, 7, 3
  elif BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 32:
    base_row, base_col = lane % 16, 8 * (lane // 16)
    swizzle_mod, swizzle_shift, swizzle_pos = 1024, 9, 5
  else:
    raise ValueError(f"unsupported base tile shape {BASE_TILE_ROWS}x{BASE_TILE_COLS}; expected 16x16 or 16x32")
  ret = []
  for inner in range(BASE_TILE_NEPT):
    offset = (base_row * BASE_TILE_COLS + base_col + inner) * DTYPE_SIZE
    # XOR the offset bits from swizzle_shift upward (within a swizzle_mod-byte window) into the bits starting at swizzle_pos
    offset ^= ((offset % swizzle_mod) >> swizzle_shift) << swizzle_pos
    offset //= DTYPE_SIZE
    ret.append((offset // BASE_TILE_COLS, offset % BASE_TILE_COLS))
  return ret
# ===
def shm_phase(inst, threadIdx_x):
match inst:
case "ds_read_b128":
match threadIdx_x:
case 0 | 1 | 2 | 3 | 12 | 13 | 14 | 15 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27: return 0
case 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 16 | 17 | 18 | 19 | 28 | 29 | 30 | 31: return 1
case 32 | 33 | 34 | 35 | 44 | 45 | 46 | 47 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59: return 2
case 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 48 | 49 | 50 | 51 | 60 | 61 | 62 | 63: return 3
case "ds_read_b64":
if threadIdx_x < 32: return 0
else: return 1
case "ds_write_b64":
if threadIdx_x < 16: return 0
elif threadIdx_x < 32: return 1
elif threadIdx_x < 48: return 2
else: return 3
def shm_bank(inst, row, col):
bank = row * (BASE_TILE_COLS // 2) + (col // 2)
match inst:
case "ds_read_b128": bank = bank % 64
case "ds_read_b64": bank = bank % 64
case "ds_write_b64": bank = bank % 32
return bank
def map_range(value, from_min, from_max, to_min, to_max):
  """Linearly rescale ``value`` from [from_min, from_max] onto [to_min, to_max].

  The result is not clamped. Equal ``from_min``/``from_max`` raise ZeroDivisionError.
  """
  in_span = from_max - from_min
  out_span = to_max - to_min
  return to_min + (value - from_min) / in_span * out_span
def shm_bank_gradient(inst, bank):
  """Return an (r, g, b) background colour for ``bank``, shading from blue (bank 0) to red (last bank).

  The bank range is 32 banks for ds_write_b64 and 64 for every other instruction. Banks outside
  that range are clamped so each component stays within [0, 120].
  """
  n_banks = 32 if inst == "ds_write_b64" else 64
  amount = int(map_range(bank, 0, n_banks - 1, 0, 120))
  # Clamp: without it, an out-of-range bank gives a negative component, which is not a valid SGR colour.
  amount = min(max(amount, 0), 120)
  return (amount, amount // 2, 120 - amount)
# colour name (as accepted by tinygrad's `colored`) for each access phase
_PHASE_COLORS = {0: "red", 1: "green", 2: "blue", 3: "yellow"}
def color_code(phase):
  """Return the text colour name for access phase 0-3, or None for any other phase."""
  return _PHASE_COLORS.get(phase)
def rgb_bg(text, color):
  """Wrap ``text`` in a 24-bit ANSI background colour built from ``color[0:3]``, then reset attributes."""
  red, green, blue = color[0], color[1], color[2]
  start = f"\033[48;2;{red};{green};{blue}m"
  return f"{start}{text}\033[0m"
def visualize_threads(inst=INST):
  """Print, for each thread of the warp, the tile (row, col) positions it owns, coloured by access phase.

  Args:
    inst: instruction whose phase table (shm_phase) colours each thread's row.

  Raises:
    AssertionError: if row_col maps two element slots across the warp to the same tile position.
  """
  unique_pairs = set()
  for threadIdx_x in range(WARP_THREADS):
    positions = row_col(threadIdx_x)
    unique_pairs.update(positions)
    # The phase depends only on the thread, so compute the colour once per row.
    # Apply it via `colored`; the old code printed the colour name itself as text.
    color = color_code(shm_phase(inst, threadIdx_x))
    cells = " ".join(colored(f"({r:3},{c:3})", color) for r, c in positions)
    print(f"Thread {threadIdx_x:2}: {cells}")
  expected = WARP_THREADS * BASE_TILE_NEPT
  assert len(unique_pairs) == expected, f"Expected {expected} unique pairs, got {len(unique_pairs)}"
def visualize_tile(inst=INST):
  """Print the tile with each cell labelled by its owning thread, then report bank conflicts.

  Cell text is coloured by the owner's access phase (shm_phase/color_code). The background
  encodes the cell's bank (shm_bank/shm_bank_gradient). Cells no thread owns print -1 on a
  black background and are excluded from conflict accounting.

  A "N way bank conflict" line is printed for every (bank, phase) pair touched by more than one
  distinct thread. Positions from row_col that fall outside the tile are skipped and counted
  in a warning.

  Args:
    inst: instruction whose phase table and bank count are modelled.
  """
  tile = [[-1] * BASE_TILE_COLS for _ in range(BASE_TILE_ROWS)]
  out_of_bounds = 0
  for threadIdx_x in range(WARP_THREADS):
    for r, c in row_col(threadIdx_x):
      # explicit bounds check instead of a bare except: negative indices would silently wrap in Python
      if 0 <= r < BASE_TILE_ROWS and 0 <= c < BASE_TILE_COLS: tile[r][c] = threadIdx_x
      else: out_of_bounds += 1
  bank_conflicts = {}  # (bank, phase) -> set of distinct owning threads
  print("\nTile layout (each number indicates the thread holding that position):")
  for r in range(BASE_TILE_ROWS):
    for c in range(BASE_TILE_COLS):
      owner = tile[r][c]
      if owner == -1:
        # unowned cell: no phase and no bank access, so it must not count as a "thread" in conflicts
        print(rgb_bg(f"{owner:2}", (0, 0, 0)), end=" ")
        continue
      phase = shm_phase(inst, owner)
      bank = shm_bank(inst, r, c)
      bank_conflicts.setdefault((bank, phase), set()).add(owner)
      text = colored(f"{owner:2}", color_code(phase))
      print(rgb_bg(text, shm_bank_gradient(inst, bank)), end=" ")
    print()
  if out_of_bounds:
    print(f"warning: {out_of_bounds} positions fell outside the {BASE_TILE_ROWS}x{BASE_TILE_COLS} tile")
  for (bank, phase), threads in bank_conflicts.items():
    if len(threads) > 1:
      print(f"{len(threads)} way bank conflict: bank {bank}")
if __name__ == "__main__":
  # Shows the tile/bank view; call visualize_threads() instead for the per-thread position listing.
  visualize_tile()