mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tk: named kernels + per kernel range id (#13522)
This commit is contained in:
@@ -5,21 +5,21 @@ from extra.thunder.tiny.tk.group import Group
|
||||
from extra.thunder.tiny.tk.tiles import GL, ST_16X16, ST_16X16_SWIZZLED, ST, RT_16X16, RT, RV, TileLayout, VecLayout
|
||||
|
||||
class _tk_range:
|
||||
user_rid = 0
|
||||
def __init__(self, start:int, end:int, step:int, axis_type:AxisType):
|
||||
def __init__(self, start:int, end:int, step:int, axis_type:AxisType, rid:int):
|
||||
self.start, self.end, self.step = start, end, step
|
||||
self.axis_type, self.done = axis_type, False
|
||||
self.axis_type, self.rid, self.done = axis_type, rid, False
|
||||
def __iter__(self): return self
|
||||
def __next__(self):
|
||||
if not self.done:
|
||||
self.done = True
|
||||
_tk_range.user_rid += 1
|
||||
self._rng = UOp.range(self.end // self.step, _tk_range.user_rid-1, axis_type=self.axis_type) * self.step + self.start
|
||||
self._rng = UOp.range(self.end // self.step, self.rid, axis_type=self.axis_type) * self.step + self.start
|
||||
return self._rng
|
||||
raise StopIteration
|
||||
|
||||
class Kernel(AbstractContextManager):
|
||||
def __init__(self, grid_size:tuple[int, int, int], block_size:int):
|
||||
def __init__(self, name:str, grid_size:tuple[int, int, int], block_size:int):
|
||||
self.name = name
|
||||
|
||||
self.blockIdx_x = UOp.special(grid_size[0], "gidx0")
|
||||
self.blockIdx_y = UOp.special(grid_size[1], "gidx1")
|
||||
self.blockIdx_z = UOp.special(grid_size[2], "gidx2")
|
||||
@@ -31,6 +31,7 @@ class Kernel(AbstractContextManager):
|
||||
self.global_slot = 0
|
||||
self.shared_slot = 0
|
||||
self.register_slot = 0
|
||||
self.range_id = 0
|
||||
self.allocs = {}
|
||||
|
||||
@property
|
||||
@@ -49,7 +50,8 @@ class Kernel(AbstractContextManager):
|
||||
|
||||
def range(self, start:int, end:int=0, step:int=1, axis_type:AxisType=AxisType.LOOP, track:bool=True):
|
||||
if end == 0: start, end = 0, start
|
||||
rng = _tk_range(start, end, step, axis_type)
|
||||
rng = _tk_range(start, end, step, axis_type, self.range_id)
|
||||
self.range_id += 1
|
||||
if track: self.range_stack.append(rng)
|
||||
return rng
|
||||
|
||||
@@ -89,7 +91,7 @@ class Kernel(AbstractContextManager):
|
||||
if hasattr(last_store, '_uop'): uop = last_store._uop
|
||||
else: uop = last_store
|
||||
|
||||
return uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
return uop.end(*rngs).sink(arg=KernelInfo(name=self.name, opts_to_apply=())).simplify()
|
||||
|
||||
def endrange(self):
|
||||
last_store = self.store_stack.pop()
|
||||
|
||||
Reference in New Issue
Block a user