tk: named kernels + per kernel range id (#13522)

This commit is contained in:
wozeparrot
2025-12-01 22:51:04 -08:00
committed by GitHub
parent 8713ae6de9
commit 1b7dbfb37f
2 changed files with 23 additions and 21 deletions

View File

@@ -5,21 +5,21 @@ from extra.thunder.tiny.tk.group import Group
from extra.thunder.tiny.tk.tiles import GL, ST_16X16, ST_16X16_SWIZZLED, ST, RT_16X16, RT, RV, TileLayout, VecLayout
class _tk_range:
user_rid = 0
def __init__(self, start:int, end:int, step:int, axis_type:AxisType):
def __init__(self, start:int, end:int, step:int, axis_type:AxisType, rid:int):
self.start, self.end, self.step = start, end, step
self.axis_type, self.done = axis_type, False
self.axis_type, self.rid, self.done = axis_type, rid, False
def __iter__(self): return self
def __next__(self):
if not self.done:
self.done = True
_tk_range.user_rid += 1
self._rng = UOp.range(self.end // self.step, _tk_range.user_rid-1, axis_type=self.axis_type) * self.step + self.start
self._rng = UOp.range(self.end // self.step, self.rid, axis_type=self.axis_type) * self.step + self.start
return self._rng
raise StopIteration
class Kernel(AbstractContextManager):
def __init__(self, grid_size:tuple[int, int, int], block_size:int):
def __init__(self, name:str, grid_size:tuple[int, int, int], block_size:int):
self.name = name
self.blockIdx_x = UOp.special(grid_size[0], "gidx0")
self.blockIdx_y = UOp.special(grid_size[1], "gidx1")
self.blockIdx_z = UOp.special(grid_size[2], "gidx2")
@@ -31,6 +31,7 @@ class Kernel(AbstractContextManager):
self.global_slot = 0
self.shared_slot = 0
self.register_slot = 0
self.range_id = 0
self.allocs = {}
@property
@@ -49,7 +50,8 @@ class Kernel(AbstractContextManager):
def range(self, start:int, end:int=0, step:int=1, axis_type:AxisType=AxisType.LOOP, track:bool=True):
if end == 0: start, end = 0, start
rng = _tk_range(start, end, step, axis_type)
rng = _tk_range(start, end, step, axis_type, self.range_id)
self.range_id += 1
if track: self.range_stack.append(rng)
return rng
@@ -89,7 +91,7 @@ class Kernel(AbstractContextManager):
if hasattr(last_store, '_uop'): uop = last_store._uop
else: uop = last_store
return uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
return uop.end(*rngs).sink(arg=KernelInfo(name=self.name, opts_to_apply=())).simplify()
def endrange(self):
last_store = self.store_stack.pop()