remove Kernel.global_dims [pr] (#11267)

All references to global axes used axis_types, so we no longer need the global_dims helper that counted GLOBAL axes.
This commit is contained in:
chenyu
2025-07-16 17:16:49 -04:00
committed by GitHub
parent 6f0ddcc24c
commit d8c783f65f

View File

@@ -129,8 +129,6 @@ class Kernel:
def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)]
@property
def global_dims(self) -> int: return len(self.axes_of(AxisType.GLOBAL))
@property
def local_dims(self) -> int: return len(self.axes_of(AxisType.LOCAL))
@property
def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
@@ -307,7 +305,8 @@ class Kernel:
self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None)
elif opt.op is OptOps.UPCAST: # yellow
check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}")
check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
# NOTE: assume the first get_local_axes() LOCAL are for TC
check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals")
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
self.shift_to(axis, amt, AxisType.UPCAST, insert_at=None)
elif opt.op is OptOps.NOLOCALS:
@@ -494,15 +493,16 @@ class Kernel:
ret = ret.replace(arg = (op.arg[0], axes))
if self.group_for_reduces and grouped_axes:
local_shape = tuple([s if self.axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE, AxisType.UNROLL) and \
(self.axis_types[i] is not AxisType.GROUP_REDUCE or i in grouped_axes) else 1 for i,s in enumerate(self.full_shape)])
st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:self.global_dims]+local_shape[self.global_dims:])
local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
st = ShapeTracker.from_shape(local_shape).expand(local_src_shape)
local_size = st.real_size()
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
if op is self.reduceops[-1]: return grouped_reduce
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)]))
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else s for i,s in enumerate(local_shape)]))
return local_buffer.view(st).load(local_buffer.view(st).store(grouped_reduce))
return ret