mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove Kernel.global_dims [pr] (#11267)
All references to global dims already use axis_types, so we no longer need the global-dims helper that was used to locate GLOBAL axes.
This commit is contained in:
@@ -129,8 +129,6 @@ class Kernel:
|
||||
|
||||
def axes_of(self, *axis_type:AxisType) -> list[int]:
  # Indices of every axis whose AxisType matches one of the requested types.
  wanted = argfix(axis_type)
  return [idx for idx, at in enumerate(self.axis_types) if at in wanted]
|
||||
@property
def global_dims(self) -> int:
  """Count of axes tagged AxisType.GLOBAL."""
  return len(self.axes_of(AxisType.GLOBAL))
|
||||
@property
def local_dims(self) -> int:
  """Count of axes tagged AxisType.LOCAL."""
  return len(self.axes_of(AxisType.LOCAL))
|
||||
@property
def upcasted(self) -> int:
  """Count of axes tagged AxisType.UPCAST or AxisType.UNROLL."""
  return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
|
||||
@@ -307,7 +305,8 @@ class Kernel:
|
||||
self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None)
|
||||
elif opt.op is OptOps.UPCAST: # yellow
|
||||
check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}")
|
||||
check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
|
||||
# NOTE: assume the first get_local_axes() LOCAL are for TC
|
||||
check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals")
|
||||
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
|
||||
self.shift_to(axis, amt, AxisType.UPCAST, insert_at=None)
|
||||
elif opt.op is OptOps.NOLOCALS:
|
||||
@@ -494,15 +493,16 @@ class Kernel:
|
||||
|
||||
ret = ret.replace(arg = (op.arg[0], axes))
|
||||
if self.group_for_reduces and grouped_axes:
|
||||
local_shape = tuple([s if self.axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE, AxisType.UNROLL) and \
|
||||
(self.axis_types[i] is not AxisType.GROUP_REDUCE or i in grouped_axes) else 1 for i,s in enumerate(self.full_shape)])
|
||||
st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:self.global_dims]+local_shape[self.global_dims:])
|
||||
local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
|
||||
local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
|
||||
local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
|
||||
st = ShapeTracker.from_shape(local_shape).expand(local_src_shape)
|
||||
local_size = st.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
|
||||
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
|
||||
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)]))
|
||||
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)]))
|
||||
return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
|
||||
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user