feat: initial tk library (#13160)

extra/thunder/tiny/tk/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
WARP_THREADS = 32

extra/thunder/tiny/tk/group.py (new file, 272 lines)
@@ -0,0 +1,272 @@
import math, functools
from typing import cast, Callable
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.helpers import getenv, prod

from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.tiles import TILE_ROW_DIM, TILE_COL_DIM, RT_BASE_TILE_NEPT, slots

class Group:
  def __init__(self, warps:int, ker):
    self.warps = warps
    self.group_threads = warps * WARP_THREADS
    self.threadIdx_x = ker.threadIdx_x
    self.ker = ker

  # helpers
  @property
  def laneid(self): return self.threadIdx_x % self.group_threads
  @property
  def warpid(self): return self.laneid // WARP_THREADS
  @property
  def groupid(self): return self.threadIdx_x // self.group_threads
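
A quick standalone check of the index helpers above (illustrative only, not part of the diff): laneid, warpid and groupid decompose the flat thread index without gaps or overlaps, for any group width.

WARP_THREADS = 32
for warps in (1, 2, 4):
  group_threads = warps * WARP_THREADS
  for tid in range(2 * group_threads):  # two groups' worth of threads
    laneid = tid % group_threads        # position within the group
    warpid = laneid // WARP_THREADS     # warp within the group
    groupid = tid // group_threads      # which group the thread belongs to
    assert tid == (groupid * warps + warpid) * WARP_THREADS + laneid % WARP_THREADS
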
  # ops that only work on a single warp

  clear_rid = 1000
  def clear(self, reg:UOp, value:float=0):
    assert self.warps == 1

    i = UOp.range(reg.size, Group.clear_rid)
    Group.clear_rid += 1
    return reg.reshape((reg.size,))[i].set(value, end=i).after(reg).reshape(reg.shape)

  def zero(self, reg:UOp): return self.clear(reg, 0)
  def neg_inf(self, reg:UOp): return self.clear(reg, -math.inf)

  copy_rid = 300
  def copy(self, dst:UOp, src:UOp):
    assert self.warps == 1

    assert dst.shape == src.shape
    assert cast(PtrDType, dst.dtype).addrspace == AddrSpace.REG
    assert cast(PtrDType, src.dtype).addrspace == AddrSpace.REG

    rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
    Group.copy_rid += len(dst.shape)

    dst_store = dst[*rngs_for_shape].store(src[*rngs_for_shape].cast(dst.dtype.base)).end(*rngs_for_shape)

    self.ker.push_store(dst_store, dst)
    return dst.after(dst_store).reshape(dst.shape)

  mma_rid = 600
  def mma_AB(self, c:UOp, a:UOp, b:UOp, after=True):
    assert self.warps == 1

    mma_i_height = UOp.range(c.shape[-3], Group.mma_rid)
    mma_i_width = UOp.range(c.shape[-2], Group.mma_rid+1)
    mma_i_inner = UOp.range(a.shape[-2], Group.mma_rid+2, AxisType.REDUCE)
    Group.mma_rid += 3

    wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())

    a_in = UOp.vectorize(*[a[mma_i_height, mma_i_inner, i] for i in range(8)])
    b_in1 = UOp.vectorize(*([b[mma_i_inner, mma_i_width, i] for i in range(2)] + [b[mma_i_inner, mma_i_width, 4+i] for i in range(2)]))
    c_out1 = UOp.vectorize(*[c[mma_i_height, mma_i_width, i] for i in range(4)])
    b_in2 = UOp.vectorize(*([b[mma_i_inner, mma_i_width, 2+i] for i in range(2)] + [b[mma_i_inner, mma_i_width, 6+i] for i in range(2)]))
    c_out2 = UOp.vectorize(*[c[mma_i_height, mma_i_width, 4+i] for i in range(4)])

    out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg)
    out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg)
    c_i = [c[mma_i_height, mma_i_width, i].store(out1.gep(i)) for i in range(4)] + [c[mma_i_height, mma_i_width, 4+i].store(out2.gep(i)) for i in range(4)]
    c_store = UOp.group(*c_i).end(mma_i_height, mma_i_width, mma_i_inner)

    self.ker.push_store(c_store, c)
    return c.after(c_store).reshape(c.shape) if after else c_store

  def mma_ABt(self, c:UOp, a:UOp, b:UOp, after=True):
    assert self.warps == 1

    mma_i_height = UOp.range(c.shape[-3], Group.mma_rid)
    mma_i_width = UOp.range(c.shape[-2], Group.mma_rid+1)
    mma_i_inner = UOp.range(a.shape[-2], Group.mma_rid+2, AxisType.REDUCE)
    Group.mma_rid += 3

    wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ())

    a_in = UOp.vectorize(*[a[mma_i_height, mma_i_inner, i] for i in range(8)])
    b_in1 = UOp.vectorize(*([b[mma_i_width, mma_i_inner, i] for i in range(2)] + [b[mma_i_width, mma_i_inner, 4+i] for i in range(2)]))
    c_out1 = UOp.vectorize(*[c[mma_i_height, mma_i_width, i] for i in range(4)])
    b_in2 = UOp.vectorize(*([b[mma_i_width, mma_i_inner, 2+i] for i in range(2)] + [b[mma_i_width, mma_i_inner, 6+i] for i in range(2)]))
    c_out2 = UOp.vectorize(*[c[mma_i_height, mma_i_width, 4+i] for i in range(4)])

    out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg)
    out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg)
    c_i = [c[mma_i_height, mma_i_width, i].store(out1.gep(i)) for i in range(4)] + [c[mma_i_height, mma_i_width, 4+i].store(out2.gep(i)) for i in range(4)]
    c_store = UOp.group(*c_i).end(mma_i_height, mma_i_width, mma_i_inner)

    self.ker.push_store(c_store, c)
    return c.after(c_store).reshape(c.shape) if after else c_store
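
Both mma methods issue the 16x16x16 tile product as two WMMA calls. A sketch of the element split, judging from the vectorize patterns above (the exact PTX mapping is an assumption): each lane's 8 B elements and 8 C accumulators are divided into complementary 4-element halves.

b_half1 = [0, 1, 4, 5]   # gathered by b_in1
b_half2 = [2, 3, 6, 7]   # gathered by b_in2
c_half1 = [0, 1, 2, 3]   # accumulated by out1
c_half2 = [4, 5, 6, 7]   # accumulated by out2
assert sorted(b_half1 + b_half2) == sorted(c_half1 + c_half2) == list(range(8))
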
  map_rid = 400
  def map(self, a:UOp, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
    assert self.warps == 1

    rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape))
    Group.map_rid += len(a.shape)

    if op.__code__.co_argcount == 1:
      to_store = op(a[*rngs_for_shape])
    else:
      to_store = op(a[*rngs_for_shape], rngs_for_shape)

    a_store = a[*rngs_for_shape].store(to_store).end(*rngs_for_shape)

    self.ker.push_store(a_store, a)
    return a.after(a_store).reshape(a.shape)
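
map dispatches on the callable's arity via __code__.co_argcount: a one-argument lambda sees only the element, a two-argument lambda also gets the index ranges. A minimal standalone illustration of that dispatch (the apply helper here is hypothetical, not in the library):

def apply(op, value, idxs):
  return op(value) if op.__code__.co_argcount == 1 else op(value, idxs)

assert apply(lambda v: v + 1, 41, (0,)) == 42
assert apply(lambda v, i: v + i[0], 40, (2,)) == 42
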
  def row_reduce(self, vec:UOp, src:UOp, op:Callable[[UOp, UOp], UOp]):
    assert self.warps == 1

    red_local = UOp.placeholder((self.group_threads, 2), src.dtype.base, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot)
    slots.shared_slot += 1

    for height in self.ker.range(src.shape[-3], track=False):
      for i_outer in self.ker.range(2, track=False):
        for width in self.ker.range(src.shape[-2], AxisType.REDUCE, track=False):
          for i_inner in self.ker.range(4, AxisType.REDUCE, track=False):
            elem_index = i_inner + 2 * (i_inner // 2) + i_outer * 2
            vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], src[height, width, elem_index])).end(width, i_inner, i_outer)
      vec = vec.after(vec_store).reshape(vec.shape)

      # store to shared memory
      for i_outer in self.ker.range(2, track=False):
        red_local_store = red_local[self.laneid, i_outer].store(vec[height, 0, i_outer]).end(i_outer)
      red_local = red_local.after(red_local_store).reshape(red_local.shape)

      # reduce from shared memory
      for i_outer in self.ker.range(2, track=False):
        for i_inner in self.ker.range(3, AxisType.REDUCE, track=False):
          offset = (self.laneid // 4) * 4 + ((self.laneid + 1 + i_inner) % 4)
          vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], red_local[offset, i_outer])).end(i_inner, i_outer)

    self.ker.push_store(vec_store, vec)
    return vec.after(vec_store).reshape(vec.shape)
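
The shared-memory pass in row_reduce stands in for warp shuffles: each lane re-reads the partials of the three other lanes in its 4-lane quad. A standalone check of the offset formula above (illustrative only):

for laneid in range(32):
  quad = laneid // 4
  partners = {quad * 4 + ((laneid + 1 + i_inner) % 4) for i_inner in range(3)}
  assert partners == {quad * 4 + j for j in range(4)} - {laneid}
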
  # ops that can work across multiple warps

  LOAD_INNER = 8
  load_rid = 100
  def load(self, dst:UOp, src:UOp, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
    assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
    dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
    if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
      srcf = src.flatten(-2)

      load_i_height = UOp.range(dst.shape[-3], Group.load_rid)
      load_i_width = UOp.range(dst.shape[-2], Group.load_rid+1)
      load_i_inner = UOp.range(RT_BASE_TILE_NEPT, Group.load_rid+2)
      Group.load_rid += 3

      if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4)
      else: local_warpid = self.warpid
      warp_laneid = self.threadIdx_x % WARP_THREADS

      if not transpose:
        row = (local_warpid * dst.shape[-3] + load_i_height) * TILE_ROW_DIM + (warp_laneid // 4)
        col = load_i_width * TILE_COL_DIM + 2 * (warp_laneid % 4)

        row_offset = ((load_i_inner % 4) // 2) * 8
        col_offset = (load_i_inner % 2) + (load_i_inner // 4) * 8
      else:
        row = (local_warpid * dst.shape[-3] + load_i_height) * TILE_ROW_DIM + 2 * (warp_laneid % 4)
        col = load_i_width * TILE_COL_DIM + (warp_laneid // 4)

        row_offset = (load_i_inner % 2) + (load_i_inner // 4) * 8
        col_offset = ((load_i_inner % 4) // 2) * 8

      src_i_last = (row + row_offset) * src.shape[-1] + col + col_offset

      dst_store = dst[*dst_idxs, load_i_height, load_i_width, load_i_inner].store(srcf[*idxs[:-2], src_i_last])
      dst_store = dst_store.end(load_i_height, load_i_width, load_i_inner)
    elif dst_dtype.addrspace == AddrSpace.LOCAL and src_dtype.addrspace == AddrSpace.GLOBAL:
      dstf = dst.flatten(-2)

      srcf = src.flatten()
      row_stride = prod(src.shape[axis+1:])

      idxs = tuple(idx * dst.shape[-2] if i == axis else idx for i, idx in enumerate(idxs))
      idxs = tuple(idx * dst.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs))
      src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]

      memcpy_per_row = dst.shape[-1] // Group.LOAD_INNER
      total_calls = prod(dst.shape[-2:]) // (self.group_threads * Group.LOAD_INNER)

      load_i_outer = UOp.range(total_calls, Group.load_rid)
      load_i_inner = UOp.range(Group.LOAD_INNER, Group.load_rid+1)
      Group.load_rid += 2

      load_idx = load_i_outer * self.group_threads + self.laneid
      row = load_idx // memcpy_per_row
      col = (load_idx * Group.LOAD_INNER) % dst.shape[-1]

      dst_i = row * dst.shape[-1] + col + load_i_inner
      src_i += row * row_stride + col + load_i_inner

      dst_store = dstf[*dst_idxs, dst_i].store(srcf[src_i]).end(load_i_outer, load_i_inner)
    else:
      raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")

    return dst.after(dst_store.barrier()).reshape(dst.shape)
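
In the non-transposed shared-to-register path, the (row, col) arithmetic above is the usual m16n8-style fragment layout: lane l owns rows l//4 and l//4+8, and element pairs starting at column 2*(l%4) in each 8-wide half. A standalone coverage check (illustrative; assumes the 16x16 base tile and 8 elements per thread defined in tiles.py):

TILE_ROW_DIM = TILE_COL_DIM = 16
seen = set()
for warp_laneid in range(32):
  for inner in range(8):  # RT_BASE_TILE_NEPT
    row = warp_laneid // 4 + ((inner % 4) // 2) * 8
    col = 2 * (warp_laneid % 4) + (inner % 2) + (inner // 4) * 8
    seen.add((row, col))
assert len(seen) == TILE_ROW_DIM * TILE_COL_DIM  # every element covered exactly once
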
  STORE_INNER = 8
  store_rid = 200
  def store(self, dst:UOp, src:UOp, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
    assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
    dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
    if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
      dstf = dst.flatten(-2)

      store_i_height = UOp.range(src.shape[-3], Group.store_rid)
      store_i_width = UOp.range(src.shape[-2], Group.store_rid+1)
      store_i_inner = UOp.range(RT_BASE_TILE_NEPT, Group.store_rid+2)
      Group.store_rid += 3

      if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4)
      else: local_warpid = self.warpid
      warp_laneid = self.threadIdx_x % WARP_THREADS

      row = (local_warpid * src.shape[-3] + store_i_height) * TILE_ROW_DIM + (warp_laneid // 4)
      col = store_i_width * TILE_COL_DIM + 2 * (warp_laneid % 4)

      row_offset = ((store_i_inner % 4) // 2) * 8
      col_offset = (store_i_inner % 2) + (store_i_inner // 4) * 8

      dst_i_last = (row + row_offset) * dst.shape[-1] + col + col_offset

      dst_store = dstf[*idxs[:-2], dst_i_last].store(src[*src_idxs, store_i_height, store_i_width, store_i_inner])
      dst_store = dst_store.end(store_i_height, store_i_width, store_i_inner)
    elif src_dtype.addrspace == AddrSpace.LOCAL and dst_dtype.addrspace == AddrSpace.GLOBAL:
      dstf = dst.flatten()
      row_stride = prod(dst.shape[axis+1:])

      idxs = tuple(idx * src.shape[-2] if i == axis else idx for i, idx in enumerate(idxs))
      idxs = tuple(idx * src.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs))
      dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3]

      srcf = src.flatten(-2)

      memcpy_per_row = src.shape[-1] // Group.STORE_INNER
      total_calls = prod(src.shape[-2:]) // (self.group_threads * Group.STORE_INNER)

      store_i_outer = UOp.range(total_calls, Group.store_rid)
      store_i_inner = UOp.range(Group.STORE_INNER, Group.store_rid+1)
      Group.store_rid += 2

      load_idx = store_i_outer * self.group_threads + self.laneid
      row = load_idx // memcpy_per_row
      col = (load_idx * Group.STORE_INNER) % src.shape[-1]

      src_i = row * src.shape[-1] + col + store_i_inner
      dst_i += row * row_stride + col + store_i_inner

      dst_store = dstf[dst_i].store(srcf[*src_idxs, src_i]).end(store_i_outer, store_i_inner)
    else:
      raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")

    self.ker.push_store(dst_store, dst)
    return dst.after(dst_store.barrier()).reshape(dst.shape) if after else dst_store
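
The global/shared paths of load and store share the same cooperative-copy indexing: each thread of the group moves LOAD_INNER/STORE_INNER contiguous elements per call. A standalone check that one 32-thread warp covers a 16x16 tile in a single call (illustrative only):

group_threads, INNER = 32, 8
rows, cols = 16, 16
memcpy_per_row = cols // INNER                        # vectorized copies per row
total_calls = rows * cols // (group_threads * INNER)  # == 1 here
seen = set()
for outer in range(total_calls):
  for lane in range(group_threads):
    idx = outer * group_threads + lane
    row, col = idx // memcpy_per_row, (idx * INNER) % cols
    seen.update((row, col + i) for i in range(INNER))
assert len(seen) == rows * cols
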
extra/thunder/tiny/tk/kernel.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from contextlib import AbstractContextManager
from tinygrad.uop.ops import UOp, KernelInfo, AxisType
from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.group import Group

class _tk_range:
  user_rid = 0
  def __init__(self, end:int, axis_type:AxisType): self.end, self.axis_type, self.done = end, axis_type, False
  def __iter__(self): return self
  def __next__(self):
    if not self.done:
      self.done = True
      _tk_range.user_rid += 1
      self._rng = UOp.range(self.end, _tk_range.user_rid-1, axis_type=self.axis_type)
      return self._rng
    raise StopIteration

class Kernel(AbstractContextManager):
  def __init__(self, grid_size:tuple[int, int, int], block_size:int):
    self.blockIdx_x = UOp.special(grid_size[0], "gidx0")
    self.blockIdx_y = UOp.special(grid_size[1], "gidx1")
    self.blockIdx_z = UOp.special(grid_size[2], "gidx2")
    self.threadIdx_x = UOp.special(block_size, "lidx0")

    self.range_stack = []
    self.store_stack = []

  @property
  def warpid(self): return self.threadIdx_x // WARP_THREADS

  def __enter__(self): return self
  def __exit__(self, exc_type, exc_value, traceback): pass

  def group(self, size:int): return Group(size, self)
  @property
  def warp(self): return self.group(1)
  @property
  def warpgroup(self): return self.group(4)

  def range(self, end:int, axis_type:AxisType=AxisType.LOOP, track:bool=True):
    rng = _tk_range(end, axis_type)
    if track: self.range_stack.append(rng)
    return rng

  def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop))

  def finish(self):
    # end all ranges
    rngs = []
    while self.range_stack: rngs.append(self.range_stack.pop(0)._rng)

    return self.store_stack.pop()[0].end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()

  def endrange(self):
    last_store = self.store_stack.pop()
    last_range = self.range_stack.pop()
    return last_store[1].after(last_store[0].barrier().end(last_range._rng)).reshape(last_store[1].shape)
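
_tk_range is what lets the tests below write ordinary Python for loops: the iterator yields a single symbolic range UOp, so the loop body executes once at trace time while the emitted kernel still iterates end times. A minimal stand-in showing the trick (no tinygrad needed; OneShot is a hypothetical name):

class OneShot:
  def __init__(self, end): self.end, self.done = end, False
  def __iter__(self): return self
  def __next__(self):
    if not self.done:
      self.done = True
      return f"RANGE(0..{self.end})"  # stand-in for UOp.range(...)
    raise StopIteration

assert [r for r in OneShot(4)] == ["RANGE(0..4)"]
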
extra/thunder/tiny/tk/tiles.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import math
from typing import cast, Callable
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.helpers import getenv, prod

from extra.thunder.tiny.tk import WARP_THREADS

class _Slots:
  def __init__(self):
    self.global_slot = 0
    self.shared_slot = 0
    self.register_slot = 0
slots = _Slots()

def gl(shape, dtype):
  slots.global_slot += 1
  return UOp.placeholder(shape, dtype, slot=slots.global_slot-1)

def st(shape, dtype):
  slots.shared_slot += 1
  return UOp.placeholder(shape, dtype, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot-1)

TILE_ROW_DIM, TILE_COL_DIM = 16, 16
RT_BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
RT_BASE_TILE_NEPT = RT_BASE_TILE_NE // WARP_THREADS

def rt(shape, dtype):
  assert len(shape) == 2

  height = shape[0] // TILE_ROW_DIM
  width = shape[1] // TILE_COL_DIM

  slots.register_slot += 1
  return UOp.placeholder((height, width, RT_BASE_TILE_NEPT), dtype, addrspace=AddrSpace.REG, slot=slots.register_slot-1)

def rv(length, dtype, layout="naive"):
  tiles = length // TILE_ROW_DIM
  match layout:
    case "naive":
      inner_dim = 1
      outer_dim = (tiles + 1) // 2
    case "ortho":
      inner_dim = 1
      outer_dim = tiles
    case _: raise NotImplementedError(f"rv layout {layout} not implemented")

  slots.register_slot += 1
  return UOp.placeholder((outer_dim, inner_dim, 2), dtype, addrspace=AddrSpace.REG, slot=slots.register_slot-1)
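
A usage sketch of the tile constructors (assumes UOp.placeholder exposes the requested shape via .shape, as group.py relies on): a 16x64 register tile is laid out as (height, width, NEPT) = (1, 4, 8), i.e. 8 elements per thread per 16x16 base tile, and an "ortho" 16-row vector as (1, 1, 2).

from tinygrad import dtypes
from extra.thunder.tiny.tk.tiles import rt, rv

t = rt((16, 64), dtypes.float32)
assert t.shape == (1, 4, 8)   # one tile row, four tile cols, 8 elems/thread
v = rv(16, dtypes.float32, "ortho")
assert v.shape == (1, 1, 2)
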
test/external/external_test_tk.py (new file, vendored, 345 lines)
@@ -0,0 +1,345 @@
import unittest

from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.engine.realize import ExecItem, get_runner

from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.kernel import Kernel
from extra.thunder.tiny.tk.tiles import gl, st, rt, rv

class TestTK(unittest.TestCase):
  @unittest.skip("store from float rt is wrong")
  def test_simple_matmul(self):
    N = 32
    BLOCK_SIZE = 16
    with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
      warp = ker.warp

      c = gl((1, 1, N, N), dtypes.float32)
      a = gl((1, 1, N, N), dtypes.bfloat16)
      b = gl((1, 1, N, N), dtypes.bfloat16)

      a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      c_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      c_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      col, row = ker.blockIdx_x, ker.blockIdx_y

      c_reg = warp.zero(c_reg)
      for tile in ker.range(N // BLOCK_SIZE):
        a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2)
        b_smem = warp.load(b_smem, b, (), (0, 0, tile, col), axis=2)

        a_reg = warp.load(a_reg, a_smem)
        b_reg = warp.load(b_reg, b_smem, transpose=True)

        c_reg = warp.mma_AB(c_reg, a_reg, b_reg)
      c_reg = ker.endrange()

      c_smem = warp.store(c_smem, c_reg)
      c = warp.store(c, c_smem, (0, 0, row, col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous()
      b = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous()
      c = Tensor.empty(1, 1, N, N, dtype="float32")
      Tensor.realize(a, b, c)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)])
      for _ in range(5): ei.run(wait=True)
      c = c.float()

      ref = a.matmul(b, dtype=dtypes.float32).float()

      assert ref.allclose(c)

  @unittest.skip("store from float rt is wrong")
  def test_simple_matmul_transposed(self):
    N = 32
    BLOCK_SIZE = 16
    with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
      warp = ker.warp

      c = gl((1, 1, N, N), dtypes.float32)
      a = gl((1, 1, N, N), dtypes.bfloat16)
      b = gl((1, 1, N, N), dtypes.bfloat16)

      a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      c_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
      c_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      col, row = ker.blockIdx_x, ker.blockIdx_y

      c_reg = warp.zero(c_reg)
      for tile in ker.range(N // BLOCK_SIZE):
        a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2)
        b_smem = warp.load(b_smem, b, (), (0, 0, col, tile), axis=2)

        a_reg = warp.load(a_reg, a_smem)
        b_reg = warp.load(b_reg, b_smem)

        c_reg = warp.mma_ABt(c_reg, a_reg, b_reg)
      c_reg = ker.endrange()

      c_smem = warp.store(c_smem, c_reg)
      c = warp.store(c, c_smem, (0, 0, row, col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous()
      b = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous()
      c = Tensor.empty(1, 1, N, N, dtype="float32")
      Tensor.realize(a, b, c)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)])
      for _ in range(5): ei.run(wait=True)
      c = c.float()

      ref = a.matmul(b.transpose(2, 3), dtype=dtypes.float32).float()

      assert ref.allclose(c)

  def test_load_store(self):
    N = 32
    BLOCK_SIZE = 16
    with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
      warp = ker.warp

      b = gl((1, 1, N, N), dtypes.float32)
      a = gl((1, 1, N, N), dtypes.float32)

      a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
      b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
      b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      col, row = ker.blockIdx_x, ker.blockIdx_y

      a_smem = warp.load(a_smem, a, (), (0, 0, row, col), axis=2)
      a_reg = warp.load(a_reg, a_smem)
      b_reg = warp.copy(b_reg, a_reg)
      b_smem = warp.store(b_smem, b_reg)
      b = warp.store(b, b_smem, (0, 0, row, col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
      b = Tensor.empty(1, 1, N, N, dtype="float32")
      Tensor.realize(a, b)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
      for _ in range(5): ei.run(wait=True)
      b = b.float()

      ref = a.float()

      assert ref.allclose(b)

  def test_max(self):
    N = 16
    BLOCK_SIZE = 16
    with Kernel((1, 1, 1), WARP_THREADS) as ker:
      warp = ker.warp

      b = gl((1, 1, N, N), dtypes.float32)
      a = gl((1, 1, N, N), dtypes.float32)

      a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
      b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
      b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      max_reg = rv(BLOCK_SIZE, dtypes.float32, "ortho")

      max_reg = warp.neg_inf(max_reg)

      for tile_row in ker.range(N // BLOCK_SIZE):
        for tile_col in ker.range(N // BLOCK_SIZE):
          a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
          a_reg = warp.load(a_reg, a_smem)
          max_reg = warp.row_reduce(max_reg, a_reg, lambda a, b: a.maximum(b))
        max_reg = ker.endrange()

        b_reg = warp.zero(b_reg).after(tile_row)
        b_reg = warp.map(b_reg, lambda _, idx: max_reg[idx[0], 0, (idx[2]%4)//2])
        b_smem = warp.store(b_smem, b_reg)

        for tile_col in ker.range(N // BLOCK_SIZE):
          b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
      b = Tensor.empty(1, 1, N, N, dtype="float32")
      Tensor.realize(a, b)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
      for _ in range(5): ei.run(wait=True)
      b = b.float()

      ref = a.float().max(axis=3, keepdim=True).expand(a.shape)

      assert ref.allclose(b)

  def test_max_nonsquare(self):
    N, M = 16, 64
    BLOCK_N, BLOCK_M = 16, 64
    with Kernel((1, 1, 1), WARP_THREADS) as ker:
      warp = ker.warp

      b = gl((1, 1, N, M), dtypes.float32)
      a = gl((1, 1, N, M), dtypes.float32)

      a_smem = st((BLOCK_N, BLOCK_M), dtypes.float32)
      b_smem = st((BLOCK_N, BLOCK_M), dtypes.float32)

      a_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32)
      b_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32)

      max_reg = rv(BLOCK_N, dtypes.float32, "ortho")

      max_reg = warp.neg_inf(max_reg)

      for tile_row in ker.range(N // BLOCK_N):
        for tile_col in ker.range(M // BLOCK_M):
          a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
          a_reg = warp.load(a_reg, a_smem)
          max_reg = warp.row_reduce(max_reg, a_reg, lambda a, b: a.maximum(b))
        max_reg = ker.endrange()

        b_reg = warp.zero(b_reg).after(tile_row)
        b_reg = warp.map(b_reg, lambda _, idx: max_reg[idx[0], 0, (idx[2]%4)//2])
        b_smem = warp.store(b_smem, b_reg)

        for tile_col in ker.range(M // BLOCK_M):
          b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.rand(1, 1, N, M, dtype="float32").contiguous()
      b = Tensor.empty(1, 1, N, M, dtype="float32")
      Tensor.realize(a, b)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
      for _ in range(5): ei.run(wait=True)
      b = b.float()

      ref = a.float().max(axis=3, keepdim=True).expand(a.shape)

      assert ref.allclose(b)

  def test_sum(self):
    N = 16
    BLOCK_SIZE = 16
    with Kernel((1, 1, 1), WARP_THREADS) as ker:
      warp = ker.warp

      b = gl((1, 1, N, N), dtypes.float32)
      a = gl((1, 1, N, N), dtypes.float32)

      a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
      b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
      b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)

      sum_reg = rv(BLOCK_SIZE, dtypes.float32, "ortho")

      for tile_row in ker.range(N // BLOCK_SIZE):
        sum_reg = warp.zero(sum_reg).after(tile_row)

        for tile_col in ker.range(N // BLOCK_SIZE):
          a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
          a_reg = warp.load(a_reg, a_smem)
          sum_reg = warp.row_reduce(sum_reg, a_reg, lambda a, b: a + b)
        sum_reg = ker.endrange()

        b_reg = warp.zero(b_reg).after(tile_row)
        b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2])
        b_smem = warp.store(b_smem, b_reg)

        for tile_col in ker.range(N // BLOCK_SIZE):
          b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.arange(1 * 1 * N * N).reshape(1, 1, N, N).cast(dtypes.float32).contiguous()
      b = Tensor.empty(1, 1, N, N, dtype="float32")
      Tensor.realize(a, b)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
      for _ in range(5): ei.run(wait=True)
      b = b.float()

      ref = a.float().sum(axis=3, keepdim=True).expand(a.shape)

      assert ref.allclose(b)

  def test_sum_nonsquare(self):
    N, M = 16, 64
    BLOCK_N, BLOCK_M = 16, 64
    with Kernel((1, 1, 1), WARP_THREADS) as ker:
      warp = ker.warp

      b = gl((1, 1, N, M), dtypes.float32)
      a = gl((1, 1, N, M), dtypes.float32)

      a_smem = st((BLOCK_N, BLOCK_M), dtypes.float32)
      b_smem = st((BLOCK_N, BLOCK_M), dtypes.float32)

      a_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32)
      b_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32)

      sum_reg = rv(BLOCK_N, dtypes.float32, "ortho")

      sum_reg = warp.zero(sum_reg)

      for tile_row in ker.range(N // BLOCK_N):
        for tile_col in ker.range(M // BLOCK_M):
          a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
          a_reg = warp.load(a_reg, a_smem)
          sum_reg = warp.row_reduce(sum_reg, a_reg, lambda a, b: a + b)
        sum_reg = ker.endrange()

        b_reg = warp.zero(b_reg).after(tile_row)
        b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2])
        b_smem = warp.store(b_smem, b_reg)

        for tile_col in ker.range(M // BLOCK_M):
          b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2)

    sink = ker.finish()

    with Context(DEBUG=0):
      a = Tensor.rand(1, 1, N, M, dtype="float32").contiguous()
      b = Tensor.empty(1, 1, N, M, dtype="float32")
      Tensor.realize(a, b)

      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
      for _ in range(5): ei.run(wait=True)
      b = b.float()

      ref = a.float().sum(axis=3, keepdim=True).expand(a.shape)

      assert ref.allclose(b)

if __name__ == "__main__":
  unittest.main()