From ccd382bc6f2301f9b9bfbcd7fe23429d4dac2dbb Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 10 Jul 2025 15:05:13 -0700
Subject: [PATCH] use axis_types more [pr] (#11172)

* use axis_types more

* fix local shape

* simpler clause

* fix local shape
---
 tinygrad/opt/kernel.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index 7b56ce4155..349d826662 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -464,12 +464,10 @@ class Kernel:
           self.upcasted, self.dont_use_locals, tuple(self.applied_opts)))
       if op.op is Ops.REDUCE_AXIS:
         reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
-
-        def reduced_axes(start, stop):
-          return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
-        axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
-        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
-
+        axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] in {AxisType.REDUCE, AxisType.UNROLL} and
+                     resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
+        grouped_axes = tuple(i for i in range(0, self.shape_len) if self.axis_types[i] is AxisType.GROUP_REDUCE and
+                             resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
         if (tc := self.tensor_core) and self.use_tensor_cores == 1:
           # get reduce/upcast axes for the tensor cores
           tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
@@ -496,10 +494,8 @@ class Kernel:
           ret = ret.replace(arg = (op.arg[0], axes))
 
         if self.group_for_reduces and grouped_axes:
-          local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
-            tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
-              for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
-            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
+          local_shape = tuple([s if self.axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE, AxisType.UNROLL) and \
+            (self.axis_types[i] is not AxisType.GROUP_REDUCE or i in grouped_axes) else 1 for i,s in enumerate(self.full_shape)])
           st = ShapeTracker.from_shape(local_shape).expand(self.full_shape[:self.global_dims]+local_shape[self.global_dims:])
           local_size = st.real_size()
           local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
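
The net effect of the patch: instead of deriving reduce and grouped-reduce axes from index ranges (first_reduce, group_for_reduces, shape_len) and assembling local_shape by concatenating per-range pieces, the kernel classifies each axis by its tag in self.axis_types and filters on that. The snippet below is a minimal, self-contained sketch of that selection idea, assuming a simplified stand-in AxisType enum and toy shapes; it is not the real Kernel code, which additionally gates each axis on the reduce ShapeTracker shapes via resolve().

# Minimal sketch, assuming a simplified AxisType enum and toy shapes; the real
# Kernel stores these as self.axis_types / self.full_shape and also checks the
# two reduce ShapeTrackers before including an axis.
from enum import Enum, auto

class AxisType(Enum):
  GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UNROLL = auto(); UPCAST = auto()

# one toy axis of each type
full_shape = (64, 16, 8, 32, 4, 4)
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GROUP_REDUCE,
              AxisType.REDUCE, AxisType.UNROLL, AxisType.UPCAST)

# before: reduce axes were picked by index range (first_reduce .. shape_len);
# after: they are picked by tag, independent of where they sit in the shape
axes = tuple(i for i, t in enumerate(axis_types) if t in {AxisType.REDUCE, AxisType.UNROLL})
grouped_axes = tuple(i for i, t in enumerate(axis_types) if t is AxisType.GROUP_REDUCE)

# the new local_shape clause: keep the size of local / upcast axes and of
# grouped-reduce axes that actually reduce, collapse everything else to 1
local_shape = tuple(s if axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE, AxisType.UNROLL) and
                    (axis_types[i] is not AxisType.GROUP_REDUCE or i in grouped_axes) else 1
                    for i, s in enumerate(full_shape))

print(axes)          # (3, 4)
print(grouped_axes)  # (2,)
print(local_shape)   # (1, 16, 8, 1, 1, 4)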