use axis_types more [pr] (#11172)

* use axis_types more

* fix local shape

* simpler clause

* fix local shape
This commit is contained in:
George Hotz
2025-07-10 15:05:13 -07:00
committed by GitHub
parent fb278c6a02
commit ccd382bc6f

View File

@@ -464,12 +464,10 @@ class Kernel:
self.upcasted, self.dont_use_locals, tuple(self.applied_opts)))
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
def reduced_axes(start, stop):
return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] in {AxisType.REDUCE, AxisType.UNROLL} and
resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
grouped_axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] is AxisType.GROUP_REDUCE and
resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
if (tc := self.tensor_core) and self.use_tensor_cores == 1:
# get reduce/upcast axes for the tensor cores
tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
@@ -496,10 +494,8 @@ class Kernel:
ret = ret.replace(arg = (op.arg[0], axes))
if self.group_for_reduces and grouped_axes:
local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
local_shape = tuple([s if self.axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE, AxisType.UNROLL) and \
(self.axis_types[i] is not AxisType.GROUP_REDUCE or i in grouped_axes) else 1 for i,s in enumerate(self.full_shape)])
st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:self.global_dims]+local_shape[self.global_dims:])
local_size = st.real_size()
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")