tk: no global copy and clear ranges (#13988)

This commit is contained in:
wozeparrot
2026-01-03 02:45:15 -05:00
committed by GitHub
parent 9f082e8e25
commit 6242a9d151
2 changed files with 9 additions and 10 deletions

View File

@@ -24,13 +24,11 @@ class Group:
# ops that only work on a single warp
clear_rid = 1000000
def clear(self, reg:ALL_TILES, value:float=0):
reg = cast(UOp, reg)
assert self.warps == 1
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
Group.clear_rid += len(reg.shape)
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in reg.shape)
reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
@@ -41,14 +39,12 @@ class Group:
def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
copy_rid = 3000000
def copy(self, dst:ALL_TILES, src:ALL_TILES):
dst, src = cast(UOp, dst), cast(UOp, src)
assert self.warps == 1
assert dst.shape == src.shape
rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
Group.copy_rid += len(dst.shape)
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in dst.shape)
src_load = src[*rngs_for_shape]
if src.dtype.base != dst.dtype.base:
@@ -219,8 +215,7 @@ class Group:
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
for height in self.ker.range(src.shape[-3], track=False):
i = UOp.range(red_reg.size, Group.clear_rid)
Group.clear_rid += 1
i = self.ker.raw_range(red_reg.size)
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
reg_store = red_reg.flatten()[i].store(init_value).end(i)
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
@@ -254,8 +249,7 @@ class Group:
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
for width in self.ker.range(src.shape[-2], track=False):
i = UOp.range(red_reg.size, Group.clear_rid)
Group.clear_rid += 1
i = self.ker.raw_range(red_reg.size)
red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
reg_store = red_reg.flatten()[i].store(init_value).end(i)
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)

View File

@@ -55,6 +55,11 @@ class Kernel(AbstractContextManager):
if track: self.range_stack.append(rng)
return rng
def raw_range(self, end:int=0, axis_type:AxisType=AxisType.LOOP):
rng = UOp.range(end, self.range_id, axis_type=axis_type)
self.range_id += 1
return rng
def alloc(self, shape, dtype, addrspace:AddrSpace, name:str|None=None):
match addrspace:
case AddrSpace.GLOBAL: