mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
index slicing + allclose (#13071)
* continue work on slicing+allclose
* Revert "Revert "slicing + allclose""
This reverts commit 6c7a12f21c.
* fix tests + better syntax
* forgot an after
* slot is an integer
This commit is contained in:
@@ -5,6 +5,7 @@ from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
N = 4096
|
||||
M = K = N
|
||||
run_count = 5
|
||||
|
||||
# ---------------------------
|
||||
@@ -81,26 +82,26 @@ def hand_spec_kernel3():
|
||||
c_regs = UOp.placeholder((ITERS_PER_WAVE_M, TM, ITERS_PER_WAVE_N, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
i = UOp.range(c_regs.size, 16)
|
||||
c_regs = c_regs[i].set(0.0, end=i)
|
||||
c_regs = c_regs.after(c_regs.flatten()[i].store(UOp.const(dtypes.float, 0.0)).end(i))
|
||||
|
||||
# pre-index the global tensors based on the global ranges
|
||||
c = c.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_N, BLOCK_N)[blockIdx_y, :, blockIdx_x, :]
|
||||
k_tile_range = UOp.range(N // BLOCK_K, 0)
|
||||
a = a.reshape(M // BLOCK_M, BLOCK_M, N // BLOCK_K, BLOCK_K)[blockIdx_y, :, k_tile_range, :]
|
||||
b = b.reshape(N // BLOCK_K, BLOCK_K, N // BLOCK_N, BLOCK_N)[k_tile_range, :, blockIdx_x, :]
|
||||
|
||||
# ---------------------------
|
||||
# GLOBAL -> LOCAL (As, Bs)
|
||||
# ---------------------------
|
||||
b = b.reshape(N // BLOCK_K, BLOCK_K,
|
||||
N // BLOCK_N, BLOCK_N)
|
||||
i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1)
|
||||
index_x = tid % BLOCK_N
|
||||
index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i
|
||||
Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i)
|
||||
Bs_store = Bs[index_y, index_x].store(b[index_y, index_x]).end(i)
|
||||
|
||||
a = a.reshape(N // BLOCK_M, BLOCK_M,
|
||||
N // BLOCK_K, BLOCK_K)
|
||||
i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2)
|
||||
index_x = tid % BLOCK_K
|
||||
index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i
|
||||
As_store = As[index_x, index_y].store(a[blockIdx_y, index_y, k_tile_range, index_x]).end(i)
|
||||
As_store = As[index_x, index_y].store(a[index_y, index_x]).end(i)
|
||||
|
||||
# TODO: can we automate barrier?
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
@@ -139,13 +140,12 @@ def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# REG -> GLOBAL (epilogue)
|
||||
# ---------------------------
|
||||
c = c.reshape(N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||
c = c.reshape(WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000)
|
||||
yt = UOp.range(TM, 1001)
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002)
|
||||
xt = UOp.range(TN, 1003)
|
||||
c_glbl_idx = c[blockIdx_y, waveIdy, iterWaveM, idyInWave, yt, blockIdx_x, waveIdx, iterWaveN, idxInWave, xt]
|
||||
c_glbl_idx = c[waveIdy, iterWaveM, idyInWave, yt, waveIdx, iterWaveN, idxInWave, xt]
|
||||
sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
|
||||
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
|
||||
@@ -37,11 +37,11 @@ WARPGROUP_SIZE = 1
|
||||
BLOCK_M = BLOCK_M * WARPGROUP_SIZE
|
||||
|
||||
# TODO: improve the syntax of this. better syntax, faster iteration
|
||||
# -- add working slice a[gx, :, i] -> shape of the : (e.g. a shape of (16,16,32) becomes (16,))
|
||||
# -- add argfix to movement (traits shared with Tensor)
|
||||
# -- DONE: add working slice a[gx, :, i] -> shape of the : (e.g. a shape of (16,16,32) becomes (16,))
|
||||
# -- DONE(ish): add argfix to movement (traits shared with Tensor)
|
||||
# -- fix WMMA to not require all the junk
|
||||
# -- improve syntax for vectorized loads/stores (both with DEVECTORIZE and without)
|
||||
# -- be able to use CONTRACT on a range
|
||||
# -- DONE: be able to use CONTRACT on a range
|
||||
# -- fix upcasted RANGE on an already vectorized buffer
|
||||
# -- improve the "all ranges not ended" error / fix the bug with .after() on ended ranges (if you are after the end of a range, that range is closed)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user