move that out

This commit is contained in:
George Hotz
2025-10-31 09:12:38 +08:00
parent ba1d1142be
commit 2543ce7585

View File

@@ -7,6 +7,10 @@ from tinygrad.helpers import getenv
N = 4096
run_count = 5
# ---------------------------
# launch/config constants
# ---------------------------
# block for locals
BN = 128
BM = 128
@@ -16,42 +20,33 @@ BK = 8
# TN/TM are the divisors used with the per-wave thread counts when computing
# nbIterWaveN/nbIterWaveM further down; presumably the per-thread tile
# dimensions -- TODO confirm.
TN = 4
TM = 4
# Variant switch taken from getenv("K5", 0); presumably the K5 environment
# variable with a default of 0 -- confirm against tinygrad.helpers.getenv.
is_kernel5 = getenv("K5", 0)
# Size handed to UOp.special(BLOCK_SIZE, "lidx0") further down:
# 128 when the kernel5 variant is selected, 256 otherwise.
BLOCK_SIZE = 128 if is_kernel5 else 256
# Module-level sanity checks: fail at import time if BLOCK_SIZE does not line
# up with the block tile sizes (BN, BM, BK).
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
# BN*BK and BM*BK must each be a whole multiple of BLOCK_SIZE; presumably so
# every thread handles the same number of tile elements -- TODO confirm.
assert (BN * BK) % BLOCK_SIZE == 0
assert (BM * BK) % BLOCK_SIZE == 0
def hand_spec_kernel3(kernel5=getenv("K5", 0)):
# ---------------------------
# launch/config constants
# ---------------------------
nbWaves = BLOCK_SIZE // 32
WN = 128 if is_kernel5 else 64
WM = BN * BM // nbWaves // WN
assert BN % WN == 0, "BN must be a multiple of WN"
assert BM % WM == 0, "BM must be a multiple of WM"
nbWaveX = BN // WN
nbWaveY = BM // WM
BLOCK_SIZE = 128 if kernel5 else 256
nbWaves = BLOCK_SIZE // 32
WN = 128 if kernel5 else 64
WM = BN * BM // nbWaves // WN
# Sanity checks (fail fast if shapes/tiles misalign)
assert BN % WN == 0, "BN must be a multiple of WN"
assert BM % WM == 0, "BM must be a multiple of WM"
nbWaveX = BN // WN
nbWaveY = BM // WM
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
assert (BN * BK) % BLOCK_SIZE == 0
assert (BM * BK) % BLOCK_SIZE == 0
nbThreadXPerWave = 8
nbThreadYPerWave = 4
nbIterWaveN = WN // (nbThreadXPerWave * TN)
nbIterWaveM = WM // (nbThreadYPerWave * TM)
SUBWN = WN // nbIterWaveN
SUBWM = WM // nbIterWaveM
nbThreadXPerWave = 8
nbThreadYPerWave = 4
nbIterWaveN = WN // (nbThreadXPerWave * TN)
nbIterWaveM = WM // (nbThreadYPerWave * TM)
SUBWN = WN // nbIterWaveN
SUBWM = WM // nbIterWaveM
def hand_spec_kernel3():
# ---------------------------
# per-thread read mapping
# ---------------------------
# A: read BK x BN tiles; B: read BN x BK tiles
threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0")
waveIdx = (threadIdx_x // 32) % nbWaveX
@@ -70,7 +65,7 @@ def hand_spec_kernel3(kernel5=getenv("K5", 0)):
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
BM_As_stride = (BM + 4) if kernel5 else BM
BM_As_stride = (BM + 4) if is_kernel5 else BM
As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL)