mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
tk: no global copy and clear ranges (#13988)
This commit is contained in:
@@ -24,13 +24,11 @@ class Group:
|
|||||||
|
|
||||||
# ops that only work on a single warp
|
# ops that only work on a single warp
|
||||||
|
|
||||||
clear_rid = 1000000
|
|
||||||
def clear(self, reg:ALL_TILES, value:float=0):
|
def clear(self, reg:ALL_TILES, value:float=0):
|
||||||
reg = cast(UOp, reg)
|
reg = cast(UOp, reg)
|
||||||
assert self.warps == 1
|
assert self.warps == 1
|
||||||
|
|
||||||
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
|
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in reg.shape)
|
||||||
Group.clear_rid += len(reg.shape)
|
|
||||||
|
|
||||||
reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
|
reg_store = reg[*rngs_for_shape].store(value).end(*rngs_for_shape)
|
||||||
|
|
||||||
@@ -41,14 +39,12 @@ class Group:
|
|||||||
def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
|
def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
|
||||||
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
|
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
|
||||||
|
|
||||||
copy_rid = 3000000
|
|
||||||
def copy(self, dst:ALL_TILES, src:ALL_TILES):
|
def copy(self, dst:ALL_TILES, src:ALL_TILES):
|
||||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||||
assert self.warps == 1
|
assert self.warps == 1
|
||||||
assert dst.shape == src.shape
|
assert dst.shape == src.shape
|
||||||
|
|
||||||
rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape))
|
rngs_for_shape = tuple(self.ker.raw_range(dim) for dim in dst.shape)
|
||||||
Group.copy_rid += len(dst.shape)
|
|
||||||
|
|
||||||
src_load = src[*rngs_for_shape]
|
src_load = src[*rngs_for_shape]
|
||||||
if src.dtype.base != dst.dtype.base:
|
if src.dtype.base != dst.dtype.base:
|
||||||
@@ -219,8 +215,7 @@ class Group:
|
|||||||
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
||||||
|
|
||||||
for height in self.ker.range(src.shape[-3], track=False):
|
for height in self.ker.range(src.shape[-3], track=False):
|
||||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
i = self.ker.raw_range(red_reg.size)
|
||||||
Group.clear_rid += 1
|
|
||||||
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
|
red_reg = red_reg.after(height, *[tkr._rng for tkr in self.ker.range_stack])
|
||||||
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
||||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||||
@@ -254,8 +249,7 @@ class Group:
|
|||||||
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
red_reg = self.ker.alloc((1,), src.dtype.base, AddrSpace.REG)
|
||||||
|
|
||||||
for width in self.ker.range(src.shape[-2], track=False):
|
for width in self.ker.range(src.shape[-2], track=False):
|
||||||
i = UOp.range(red_reg.size, Group.clear_rid)
|
i = self.ker.raw_range(red_reg.size)
|
||||||
Group.clear_rid += 1
|
|
||||||
red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
|
red_reg = red_reg.after(width, *[tkr._rng for tkr in self.ker.range_stack])
|
||||||
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
reg_store = red_reg.flatten()[i].store(init_value).end(i)
|
||||||
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
red_reg = red_reg.after(reg_store).reshape(red_reg.shape)
|
||||||
|
|||||||
@@ -55,6 +55,11 @@ class Kernel(AbstractContextManager):
|
|||||||
if track: self.range_stack.append(rng)
|
if track: self.range_stack.append(rng)
|
||||||
return rng
|
return rng
|
||||||
|
|
||||||
|
def raw_range(self, end:int=0, axis_type:AxisType=AxisType.LOOP):
|
||||||
|
rng = UOp.range(end, self.range_id, axis_type=axis_type)
|
||||||
|
self.range_id += 1
|
||||||
|
return rng
|
||||||
|
|
||||||
def alloc(self, shape, dtype, addrspace:AddrSpace, name:str|None=None):
|
def alloc(self, shape, dtype, addrspace:AddrSpace, name:str|None=None):
|
||||||
match addrspace:
|
match addrspace:
|
||||||
case AddrSpace.GLOBAL:
|
case AddrSpace.GLOBAL:
|
||||||
|
|||||||
Reference in New Issue
Block a user