mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-05 05:04:27 -05:00
tk: no global copy and clear ranges (#13988)
This commit is contained in:
@@ -24,13 +24,11 @@ class Group:
|
||||
|
||||
# ops that only work on a single warp
|
||||
|
||||
clear_rid = 1000000
|
||||
def clear(self, reg:ALL_TILES, value:float=0):
|
||||
reg = cast(UOp, reg)
|
||||
assert self.warps == 1
|
||||
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
|
||||
Group.clear_rid += len(reg.shape)
|
||||
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in reg.shape)
|
||||
|
||||
reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
|
||||
|
||||
@@ -41,14 +39,12 @@ class Group:
|
||||
def ones(self, reg:ALL_TILES):
  """Set `reg` to all ones.

  Thin convenience wrapper: forwards to `clear` with a fill value of 1 and
  returns whatever `clear` returns.
  """
  return self.clear(reg, value=1)
|
||||
def neg_inf(self, reg:ALL_TILES):
  """Set `reg` to negative infinity.

  Thin convenience wrapper: forwards to `clear` with a fill value of
  `-math.inf` and returns whatever `clear` returns.
  """
  fill = -math.inf
  return self.clear(reg, fill)
|
||||
|
||||
copy_rid = 3000000
|
||||
def copy(self, dst:ALL_TILES, src:ALL_TILES):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
assert dst.shape == src.shape
|
||||
|
||||
rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
|
||||
Group.copy_rid += len(dst.shape)
|
||||
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in dst.shape)
|
||||
|
||||
src_load = src[*rngs_for_shape]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
@@ -219,8 +215,7 @@ class Group:
|
||||
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
i = self.ker.raw_range(red_reg.size)
|
||||
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
|
||||
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
@@ -254,8 +249,7 @@ class Group:
|
||||
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
||||
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
||||
Group.clear_rid += 1
|
||||
i = self.ker.raw_range(red_reg.size)
|
||||
red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
|
||||
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||
|
||||
@@ -55,6 +55,11 @@ class Kernel(AbstractContextManager):
|
||||
if track: self.range_stack.append(rng)
|
||||
return rng
|
||||
|
||||
def raw_range(self, end:int=0, axis_type:AxisType=AxisType.LOOP):
  """Create a range UOp with a fresh id taken from this kernel's counter.

  The id comes from `self.range_id`, which is advanced by one after the range
  has been built, so successive calls yield distinct ids. This method does not
  touch `self.range_stack`.

  Args:
    end: forwarded to `UOp.range` as its first argument.
    axis_type: forwarded to `UOp.range` as the `axis_type` keyword.

  Returns:
    The UOp produced by `UOp.range`.
  """
  rid = self.range_id
  new_rng = UOp.range(end, rid, axis_type=axis_type)
  # bump the counter only once the range was created successfully
  self.range_id = rid + 1
  return new_rng
|
||||
|
||||
def alloc(self, shape, dtype, addrspace:AddrSpace, name:str|None=None):
|
||||
match addrspace:
|
||||
case AddrSpace.GLOBAL:
|
||||
|
||||
Reference in New Issue
Block a user