mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 17:38:06 -05:00
progress
This commit is contained in:
@@ -60,140 +60,60 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
blockIdx_x = UOp.special(N//BN, "gidx0")
|
||||
blockIdx_y = UOp.special(N//BM, "gidx1")
|
||||
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
|
||||
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
|
||||
a = UOp.placeholder(dtypes.float, (N, N), slot=1)
|
||||
b = UOp.placeholder(dtypes.float, (N, N), slot=2)
|
||||
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
|
||||
|
||||
BM_As_stride = (BM+4) if kernel5 else BM
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM_As_stride, AddrSpace.LOCAL), arg=0)
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||
As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
|
||||
Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||
A_col = UOp.placeholder(dtypes.float, (nbIterWaveM, TM), slot=0, addrspace=AddrSpace.REG)
|
||||
B_row = UOp.placeholder(dtypes.float, (nbIterWaveN, TN), slot=1, addrspace=AddrSpace.REG)
|
||||
c_regs = UOp.placeholder(dtypes.float, (TM * nbIterWaveM * TN * nbIterWaveN,), slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
i = UOp.range(c_regs.dtype.size, 16)
|
||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0)).end(i)
|
||||
c_regs = c_regs[i].set(0.0, end=i)
|
||||
|
||||
if kernel4:
|
||||
regA = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsA, AddrSpace.REG), arg=3)
|
||||
regB = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsB, AddrSpace.REG), arg=4)
|
||||
kId_range = UOp.range(N//BK, 0)
|
||||
kId = kId_range*BK
|
||||
|
||||
# initial load from globals into locals (0)
|
||||
kId = 0
|
||||
# load from globals into locals
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[index_y % BK, index_x % BN].store(b[index_y, index_x]).end(i)
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(nbReadsB, 0)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x], i)
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[index_x % BK, index_y % BM].store(a[index_y, index_x]).end(i)
|
||||
|
||||
i = UOp.range(nbReadsA, 1)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x], i)
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
# iterate over the middle chunk
|
||||
kId_range = UOp.range(N//BK-1, 2)
|
||||
kId = kId_range*BK
|
||||
k = UOp.range(BK, 3)
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave, i].store(Bs.after(barrier)[k, index]).end(iterWave, i)
|
||||
|
||||
# load from globals into registers (next round)
|
||||
i = UOp.range(nbReadsB, 3)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
regB_store = regB[i].store(b[N * index_y + index_x], i)
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave, i].store(As.after(barrier)[k, index]).end(iterWave, i)
|
||||
|
||||
i = UOp.range(nbReadsA, 4)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
regA_store = regA[i].store(a[N * index_y + index_x], i)
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
|
||||
def inner_loop(first_range, inp_dep=()):
|
||||
# inner unroll
|
||||
k = UOp.range(BK, first_range+0)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(nbIterWaveN, first_range+1)
|
||||
i = UOp.range(TN, first_range+2)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].after(*inp_dep), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(nbIterWaveM, first_range+3)
|
||||
i = UOp.range(TM, first_range+4)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].after(*inp_dep), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(nbIterWaveM, first_range+5)
|
||||
yt = UOp.range(TM, first_range+6)
|
||||
iterWaveN = UOp.range(nbIterWaveN, first_range+7)
|
||||
xt = UOp.range(TN, first_range+8)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
# sketchy, this should end the kId_range but it doesn't
|
||||
sink = c_regs_idx.store(c_regs_idx.after(init_store) + A_col[y].after(A_col_store) * B_row[x].after(B_row_store)).end(iterWaveM, iterWaveN, yt, xt, k)
|
||||
return sink
|
||||
|
||||
# TODO: kId_range should endrange after a barrier
|
||||
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
||||
|
||||
# load from registers into locals
|
||||
i = UOp.range(nbReadsB, 14)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
||||
|
||||
i = UOp.range(nbReadsA, 15)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
||||
|
||||
# final iteration without the copy
|
||||
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
||||
else:
|
||||
kId_range = UOp.range(N//BK, 0)
|
||||
kId = kId_range*BK
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x]).end(i)
|
||||
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x]).end(i)
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
k = UOp.range(BK, 3)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs.after(barrier)[k*BN + index]).end(iterWave, i)
|
||||
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As.after(barrier)[k*BM_As_stride + index]).end(iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
|
||||
gemm = c_regs.after(init_store)[y * TN * nbIterWaveN + x] + A_col.after(A_col_store)[y] * B_row.after(B_row_store)[x]
|
||||
sink = c_regs[y * TN * nbIterWaveN + x].store(gemm).end(iterWaveM, iterWaveN, yt, xt, k).barrier().end(kId_range)
|
||||
gemm = c_regs[y * TN * nbIterWaveN + x] + A_col.after(A_col_store)[y] * B_row.after(B_row_store)[x]
|
||||
sink = c_regs[y * TN * nbIterWaveN + x].store(gemm).end(iterWaveM, iterWaveN, yt, xt, k).barrier().end(kId_range)
|
||||
|
||||
# store c_regs into c
|
||||
iterWaveM = UOp.range(nbIterWaveM, 1000)
|
||||
@@ -202,8 +122,8 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
xt = UOp.range(TN, 1003)
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
indexC = N * (yOut + yt) + xOut + xt
|
||||
sink = c[indexC].store(c_regs.after(sink)[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)]).end(iterWaveM, iterWaveN, yt, xt)
|
||||
sink = c[yOut + yt, xOut + xt].store(c_regs.after(sink)[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)])
|
||||
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
return sink.sink(arg=KernelInfo(name="tinygemm", opts_to_apply=()))
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
|
||||
from tinygrad.helpers import strip_parens, colored
|
||||
@@ -370,7 +370,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if len(src) == 0: return self
|
||||
return UOp(Ops.END, src=(self,)+src)
|
||||
def after(self, *src:UOp, **kwargs):
|
||||
assert self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER, Ops.AFTER}, f"after can't be placed on {self.op}"
|
||||
assert self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER, Ops.AFTER, Ops.RESHAPE}, f"after can't be placed on {self.op}"
|
||||
return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
@@ -758,8 +758,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# *** uop high level syntactic sugar ***
|
||||
|
||||
@staticmethod
|
||||
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int):
|
||||
ret = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot)
|
||||
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
|
||||
if len(shape) > 1: ret = ret.reshape(shape)
|
||||
return ret
|
||||
def placeholder_like(self, slot:int):
|
||||
|
||||
@@ -156,6 +156,7 @@ kernel_spec = PatternMatcher([
|
||||
# RESHAPE (but only RESHAPE) is allowed here
|
||||
(UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.VCONST, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user