mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move that out
This commit is contained in:
@@ -7,6 +7,10 @@ from tinygrad.helpers import getenv
|
||||
N = 4096
|
||||
run_count = 5
|
||||
|
||||
# ---------------------------
|
||||
# launch/config constants
|
||||
# ---------------------------
|
||||
|
||||
# block for locals
|
||||
BN = 128
|
||||
BM = 128
|
||||
@@ -16,42 +20,33 @@ BK = 8
|
||||
TN = 4
|
||||
TM = 4
|
||||
|
||||
is_kernel5 = getenv("K5", 0)
|
||||
BLOCK_SIZE = 128 if is_kernel5 else 256
|
||||
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
|
||||
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
|
||||
assert (BN * BK) % BLOCK_SIZE == 0
|
||||
assert (BM * BK) % BLOCK_SIZE == 0
|
||||
|
||||
def hand_spec_kernel3(kernel5=getenv("K5", 0)):
|
||||
# ---------------------------
|
||||
# launch/config constants
|
||||
# ---------------------------
|
||||
nbWaves = BLOCK_SIZE // 32
|
||||
WN = 128 if is_kernel5 else 64
|
||||
WM = BN * BM // nbWaves // WN
|
||||
assert BN % WN == 0, "BN must be a multiple of WN"
|
||||
assert BM % WM == 0, "BM must be a multiple of WM"
|
||||
nbWaveX = BN // WN
|
||||
nbWaveY = BM // WM
|
||||
|
||||
BLOCK_SIZE = 128 if kernel5 else 256
|
||||
|
||||
nbWaves = BLOCK_SIZE // 32
|
||||
WN = 128 if kernel5 else 64
|
||||
WM = BN * BM // nbWaves // WN
|
||||
|
||||
# Sanity checks (fail fast if shapes/tiles misalign)
|
||||
assert BN % WN == 0, "BN must be a multiple of WN"
|
||||
assert BM % WM == 0, "BM must be a multiple of WM"
|
||||
nbWaveX = BN // WN
|
||||
nbWaveY = BM // WM
|
||||
|
||||
assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN"
|
||||
assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK"
|
||||
|
||||
assert (BN * BK) % BLOCK_SIZE == 0
|
||||
assert (BM * BK) % BLOCK_SIZE == 0
|
||||
|
||||
nbThreadXPerWave = 8
|
||||
nbThreadYPerWave = 4
|
||||
nbIterWaveN = WN // (nbThreadXPerWave * TN)
|
||||
nbIterWaveM = WM // (nbThreadYPerWave * TM)
|
||||
SUBWN = WN // nbIterWaveN
|
||||
SUBWM = WM // nbIterWaveM
|
||||
nbThreadXPerWave = 8
|
||||
nbThreadYPerWave = 4
|
||||
nbIterWaveN = WN // (nbThreadXPerWave * TN)
|
||||
nbIterWaveM = WM // (nbThreadYPerWave * TM)
|
||||
SUBWN = WN // nbIterWaveN
|
||||
SUBWM = WM // nbIterWaveM
|
||||
|
||||
def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# per-thread read mapping
|
||||
# ---------------------------
|
||||
# A: read BK x BN tiles; B: read BN x BK tiles
|
||||
|
||||
threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0")
|
||||
|
||||
waveIdx = (threadIdx_x // 32) % nbWaveX
|
||||
@@ -70,7 +65,7 @@ def hand_spec_kernel3(kernel5=getenv("K5", 0)):
|
||||
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
|
||||
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
|
||||
|
||||
BM_As_stride = (BM + 4) if kernel5 else BM
|
||||
BM_As_stride = (BM + 4) if is_kernel5 else BM
|
||||
As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
|
||||
Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user