mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
157 lines
4.6 KiB
Python
157 lines
4.6 KiB
Python
from tinygrad.helpers import colored
|
|
|
|
# Geometry of the simulated warp and of the tile it reads/writes.
WARP_THREADS = 64  # threads per warp; the lane id is threadIdx_x % WARP_THREADS

# Tile shape in elements. row_col only defines a layout/swizzle for 16x16 and 16x32.
BASE_TILE_ROWS, BASE_TILE_COLS = 16, 16

# Elements held by each thread, so WARP_THREADS * BASE_TILE_NEPT equals the tile's element count.
BASE_TILE_NEPT = BASE_TILE_ROWS * BASE_TILE_COLS // WARP_THREADS

DTYPE_SIZE = 2  # bytes per element (16-bit elements)

# Instruction whose phase/bank pattern is visualized by default.
INST = "ds_read_b64"
|
|
|
|
def row_col(threadIdx_x):
  """Return the swizzled (row, col) tile coordinates of every element a thread holds.

  The lane id (threadIdx_x mod WARP_THREADS) selects a starting element, and the
  thread holds BASE_TILE_NEPT consecutive columns from there. Each element's
  position in the row-major BASE_TILE_ROWS x BASE_TILE_COLS tile is turned into a
  byte offset (DTYPE_SIZE bytes per element), XOR-swizzled with a tile-shape
  specific pattern, and mapped back to (row, col).

  Returns:
    list of BASE_TILE_NEPT (row, col) tuples.

  Raises:
    ValueError: if the tile shape is not 16x16 or 16x32, the only shapes with a
      defined layout and swizzle here.
  """
  lane = threadIdx_x % WARP_THREADS

  # Unswizzled start position: each run of 16 lanes walks down the 16 rows, and
  # the next run of 16 lanes starts one column chunk further right.
  # The swizzle XORs bits taken from (offset % swz_mod) >> swz_src back in at swz_dst.
  if BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 16:
    base_row, base_col = lane % 16, 4 * (lane // 16)
    swz_mod, swz_src, swz_dst = 512, 7, 3  # bits 7-8 of the byte offset flip bits 3-4
  elif BASE_TILE_ROWS == 16 and BASE_TILE_COLS == 32:
    base_row, base_col = lane % 16, 8 * (lane // 16)
    swz_mod, swz_src, swz_dst = 1024, 9, 5  # bit 9 of the byte offset flips bit 5
  else:
    raise ValueError(f"unsupported tile shape {BASE_TILE_ROWS}x{BASE_TILE_COLS}; expected 16x16 or 16x32")

  ret = []
  for inner in range(BASE_TILE_NEPT):
    # byte offset of the unswizzled element; consecutive elements are consecutive columns
    offset = (base_row * BASE_TILE_COLS + (base_col + inner)) * DTYPE_SIZE
    offset ^= ((offset % swz_mod) >> swz_src) << swz_dst

    # back to an element index, then to row-major (row, col)
    elem = offset // DTYPE_SIZE
    ret.append((elem // BASE_TILE_COLS, elem % BASE_TILE_COLS))

  return ret
|
|
|
|
# ===
|
|
|
|
def shm_phase(inst, threadIdx_x):
|
|
match inst:
|
|
case "ds_read_b128":
|
|
match threadIdx_x:
|
|
case 0 | 1 | 2 | 3 | 12 | 13 | 14 | 15 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27: return 0
|
|
case 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 16 | 17 | 18 | 19 | 28 | 29 | 30 | 31: return 1
|
|
case 32 | 33 | 34 | 35 | 44 | 45 | 46 | 47 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59: return 2
|
|
case 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 48 | 49 | 50 | 51 | 60 | 61 | 62 | 63: return 3
|
|
case "ds_read_b64":
|
|
if threadIdx_x < 32: return 0
|
|
else: return 1
|
|
case "ds_write_b64":
|
|
if threadIdx_x < 16: return 0
|
|
elif threadIdx_x < 32: return 1
|
|
elif threadIdx_x < 48: return 2
|
|
else: return 3
|
|
|
|
def shm_bank(inst, row, col):
  """Return the shared-memory bank that tile element (row, col) maps to for `inst`.

  Two adjacent elements of a row share one bank. The bank index wraps modulo 64
  for ds_read_b128/ds_read_b64 and modulo 32 for ds_write_b64. Any other
  instruction gets the un-wrapped index.
  """
  word = col // 2 + row * (BASE_TILE_COLS // 2)
  if inst in ("ds_read_b128", "ds_read_b64"):
    return word % 64
  if inst == "ds_write_b64":
    return word % 32
  return word
|
|
|
|
def map_range(value, from_min, from_max, to_min, to_max):
  """Linearly rescale `value` from [from_min, from_max] onto [to_min, to_max].

  No clamping is applied, so values outside the input range extrapolate.
  from_min must differ from from_max.
  """
  span_in = from_max - from_min
  span_out = to_max - to_min
  return to_min + span_out * ((value - from_min) / span_in)
|
|
|
|
def shm_bank_gradient(inst, bank):
  """Return an (r, g, b) background colour for `bank`, shading from blue towards red.

  The bank index is scaled onto 0..120 across the instruction's bank count
  (32 banks for ds_write_b64, 64 otherwise). Bank 0 is (0, 0, 120) and the last
  bank is (120, 60, 0).
  """
  num_banks = 32 if inst == "ds_write_b64" else 64
  level = int(map_range(bank, 0, num_banks - 1, 0, 120))
  return (level, level // 2, 120 - level)
|
|
|
|
def color_code(phase):
  """Return the text colour name for phase 0-3, or None for any other phase."""
  return {0: "red", 1: "green", 2: "blue", 3: "yellow"}.get(phase)
|
|
|
|
def rgb_bg(text, color):
  """Wrap `text` in a 24-bit ANSI background-colour escape followed by a reset.

  `color` is indexed as (red, green, blue).
  """
  red, green, blue = color[0], color[1], color[2]
  return f"\033[48;2;{red};{green};{blue}m{text}\033[0m"
|
|
|
|
def visualize_threads(inst=INST):
  """Print the (row, col) coordinates each thread holds, coloured by its phase for `inst`.

  Afterwards, assert that the WARP_THREADS * BASE_TILE_NEPT coordinates produced
  by row_col are pairwise distinct, so no two thread slots map to the same cell.
  """
  for threadIdx_x in range(WARP_THREADS):
    # The phase depends only on the thread, so resolve its colour once per thread.
    # (Previously the colour *name* was printed literally instead of applied.)
    color = color_code(shm_phase(inst, threadIdx_x))
    print(f"Thread {threadIdx_x:2}: ", end="")
    for r, c in row_col(threadIdx_x):
      print(colored(f"({r:3},{c:3})", color), end=" ")
    print()

  unique_pairs = {rc for threadIdx_x in range(WARP_THREADS) for rc in row_col(threadIdx_x)}
  expected = WARP_THREADS * BASE_TILE_NEPT
  assert len(unique_pairs) == expected, f"Expected {expected} unique pairs, got {len(unique_pairs)}"
|
|
|
|
def visualize_tile(inst=INST):
  """Print the tile showing which thread owns each cell, then report bank conflicts for `inst`.

  Each cell shows its owning thread id. The text is coloured by the thread's
  phase (shm_phase / color_code), and the background is shaded by the cell's bank
  (shm_bank / shm_bank_gradient). Cells that no thread maps to keep the value -1,
  get a black background, and are excluded from conflict accounting. A conflict
  is reported when more than one distinct thread in the same phase touches the
  same bank.
  """
  # tile[r][c] = id of the thread holding that element, or -1 if none does
  tile = [[-1] * BASE_TILE_COLS for _ in range(BASE_TILE_ROWS)]
  for threadIdx_x in range(WARP_THREADS):
    for r, c in row_col(threadIdx_x):
      try:
        tile[r][c] = threadIdx_x
      except IndexError:
        pass  # coordinates that fall outside the tile are skipped on purpose

  # (bank, phase) -> [(row, col, thread), ...] over all covered cells
  bank_conflicts = {}

  print("\nTile layout (each number indicates the thread holding that position):")
  for r in range(BASE_TILE_ROWS):
    for c in range(BASE_TILE_COLS):
      owner = tile[r][c]
      phase = shm_phase(inst, owner)
      bank = shm_bank(inst, r, c)

      # Detect empty cells from the tile itself. shm_phase may map -1 to a real
      # phase, so checking the phase alone is not enough.
      if owner == -1 or phase == -1:
        bank_color = (0, 0, 0)
      else:
        bank_color = shm_bank_gradient(inst, bank)
        bank_conflicts.setdefault((bank, phase), []).append((r, c, owner))

      text = rgb_bg(colored(f"{owner:2}", color_code(phase)), bank_color)
      print(text, end=" ")
    print()

  for (bank, phase), positions in bank_conflicts.items():
    unique_threads = {thread for _, _, thread in positions}
    if len(unique_threads) > 1:
      print(f"{len(unique_threads)} way bank conflict: bank {bank}")
|
|
|
|
if __name__ == "__main__":
  # Default action: render the tile ownership/bank map for INST.
  # Call visualize_threads() instead to list each thread's coordinates.
  visualize_tile()
|