Files
tinygrad/extra/thunder/tiny/tk/tiles.py
2025-11-09 22:54:29 -08:00

53 lines
1.6 KiB
Python

import math
from typing import cast, Callable
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.helpers import getenv, prod
from extra.thunder.tiny.tk import WARP_THREADS
class _Slots:
  """Bundle of integer counters handing out the next free placeholder slot.

  One counter is kept per placeholder kind created in this module:
  `global_slot`, `shared_slot` and `register_slot`.
  """
  def __init__(self):
    # all three counters start at zero; assigned left to right
    self.global_slot = self.shared_slot = self.register_slot = 0


# module-level counter instance mutated by the placeholder constructors below
slots = _Slots()
def gl(shape, dtype):
  """Return a placeholder UOp of `shape`/`dtype` bound to the next global slot.

  No address space is passed, so `UOp.placeholder`'s default applies.
  Each call consumes one value of `slots.global_slot`.
  """
  slot = slots.global_slot
  # reserve the slot before building the placeholder
  slots.global_slot = slot + 1
  return UOp.placeholder(shape, dtype, slot=slot)
# NOTE(review): nothing in the visible code reads this module-level name; st() counts
# through `slots.shared_slot` instead. Possibly a leftover - confirm before removing.
shared_slot = 0
def st(shape, dtype):
  """Return a placeholder UOp of `shape`/`dtype` in `AddrSpace.LOCAL`.

  Each call consumes one value of `slots.shared_slot`.
  """
  slot = slots.shared_slot
  # reserve the slot before building the placeholder
  slots.shared_slot = slot + 1
  return UOp.placeholder(shape, dtype, addrspace=AddrSpace.LOCAL, slot=slot)
# row and column extent of one base tile; rt() and rv() floor-divide requested sizes by these
TILE_ROW_DIM, TILE_COL_DIM = 16, 16
# total number of elements in one base tile (rows * cols)
RT_BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
# base-tile element count floor-divided by WARP_THREADS; used as the last dimension of rt() placeholders
# (presumably "number of elements per thread" - confirm against WARP_THREADS' definition)
RT_BASE_TILE_NEPT = RT_BASE_TILE_NE // WARP_THREADS
# NOTE(review): nothing in the visible code reads this module-level name; rt() and rv() count
# through `slots.register_slot` instead. Possibly a leftover - confirm before removing.
register_slot = 0
def rt(shape, dtype):
  """Return an `AddrSpace.REG` placeholder UOp for a 2D shape expressed in base tiles.

  `shape` must have exactly two entries. Each is floor-divided by the matching
  tile dimension, so any remainder is dropped. The placeholder has shape
  `(shape[0] // TILE_ROW_DIM, shape[1] // TILE_COL_DIM, RT_BASE_TILE_NEPT)`.
  Each call consumes one value of `slots.register_slot`.
  """
  assert len(shape) == 2

  tiles_down = shape[0] // TILE_ROW_DIM
  tiles_across = shape[1] // TILE_COL_DIM

  slot = slots.register_slot
  # reserve the slot before building the placeholder
  slots.register_slot = slot + 1
  tiled_shape = (tiles_down, tiles_across, RT_BASE_TILE_NEPT)
  return UOp.placeholder(tiled_shape, dtype, addrspace=AddrSpace.REG, slot=slot)
def rv(length, dtype, layout="naive"):
  """Return an `AddrSpace.REG` placeholder UOp for a vector of `length` elements.

  `length` is floor-divided by `TILE_ROW_DIM` to get a tile count. The
  placeholder has shape `(outer, 1, 2)`, where `outer` depends on `layout`:

  - "naive": `(tiles + 1) // 2`, i.e. half the tile count rounded up for integer counts
  - "ortho": the tile count itself

  Any other `layout` raises `NotImplementedError`. Each successful call
  consumes one value of `slots.register_slot`.
  """
  n_tiles = length // TILE_ROW_DIM

  if layout == "naive":
    outer = (n_tiles + 1) // 2
  elif layout == "ortho":
    outer = n_tiles
  else:
    raise NotImplementedError(f"rv layout {layout} not implemented")

  # both supported layouts use an inner dimension of 1
  inner = 1

  slot = slots.register_slot
  # reserve the slot only after the layout has been validated
  slots.register_slot = slot + 1
  return UOp.placeholder((outer, inner, 2), dtype, addrspace=AddrSpace.REG, slot=slot)