diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index a7d4985774..ee9f585f73 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -844,8 +844,8 @@ class TestFloat4(unittest.TestCase):
       s = c.schedule()[0]
       k = Kernel(s.ast)
-      k.shift_to(k.first_upcast-1, 4, AxisType.UPCAST) # manual trigger float4 dim
-      k.shift_to(k.first_upcast-1, shift, AxisType.UPCAST, insert_before=k.shape_len-1)
+      k.shift_to(1, 4, AxisType.UPCAST) # manual trigger float4 dim
+      k.shift_to(1, shift, AxisType.UPCAST, insert_before=k.shape_len-1)
       return get_program(k.get_optimized_ast(), k.opts).uops
 
     sizes = [13, 9, 17]
diff --git a/tinygrad/opt/heuristic.py b/tinygrad/opt/heuristic.py
index 9146f44cf2..f6a87af731 100644
--- a/tinygrad/opt/heuristic.py
+++ b/tinygrad/opt/heuristic.py
@@ -66,9 +66,9 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
     # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
     for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
       # if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
-      if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and \
-        any(st.views[-1].strides[axis] == 0 and not any(x == 0 for x in st.real_strides()[k.first_upcast:]) \
-          for st in k.sts):
+      if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
+      if any(st.views[-1].strides[axis] == 0 and \
+        all(x != 0 for t,x in zip(k.axis_types, st.real_strides()) if t in (AxisType.UPCAST, AxisType.UNROLL)) for st in k.sts):
         xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
     if xb_choices:
@@ -79,8 +79,8 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
     else: break
 
   # if last reduce dim is small(ish), loop unroll the reduce
-  if k.unrollable_dims and \
-    (prod(k.full_shape[k.first_upcast:]) <= 4 or (AxisType.UNROLL not in k.axis_types)) and (prod(k.full_shape[k.first_upcast:]) < 64):
+  upcast_size = prod(s for s,t in zip(k.full_shape, k.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL))
+  if k.unrollable_dims and (upcast_size <= 4 or (AxisType.UNROLL not in k.axis_types)) and (upcast_size < 64):
     if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
       k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims[-1]-k.first_reduce, 0))
       # if it's small, upcast a second reduce dimension too
diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index c84d971560..0968c3264b 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -207,10 +207,11 @@ class Kernel:
     return False
 
   def simplify_merge_adjacent(self):
+    assert not hasattr(self, 'axis_types'), "don't call this after init"
    if self.shape_len == 0: return
     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
     # NOTE: we can't use self.first_reduce yet
-    first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
+    first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape+(0,), self.full_shape+(1,))].index(True)
 
     # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
     # TODO: remove membufs
@@ -285,7 +286,7 @@ class Kernel:
     if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
        (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
       acc_sz = self.reduceop.dtype.itemsize
-      upcast_sz = prod([a for a,b in zip(self.full_shape[self.first_upcast:], self.sts[0].shape[self.first_upcast:]) if a == b])
+      upcast_sz = prod([s for s,t in zip(self.full_shape, self.axis_types) if t is AxisType.UPCAST])
       local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
       smem_sz = amt*acc_sz*upcast_sz*local_sz
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
@@ -298,12 +299,12 @@ class Kernel:
         self.shift_to(axis, amt, AxisType.LOCAL, insert_before=self.first_reduce)
     elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
       check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
-      check(self.first_reduce + self.group_for_reduces <= axis < self.first_upcast, "must be reduce axis to group")
+      check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
       check(not self.tensor_core, "can't group with tensor cores")
       check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
       self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
     elif opt.op is OptOps.UNROLL: # purple
-      check(axis < self.first_upcast, "can't upcasted already upcasted")
+      check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
       check(amt <= 32, "don't unroll more than 32")
       self.shift_to(axis, amt, AxisType.UNROLL, insert_before=None)
     elif opt.op is OptOps.UPCAST: # yellow
@@ -322,7 +323,7 @@ class Kernel:
       self.permute(tuple(permute))
     elif opt.op is OptOps.PADTO:
       check(not self.vars, "does not work with symbolic shape")
-      check(axis < self.first_upcast, "cannot pad upcasted")
+      check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
       if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
       padded = False