diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 39be6b646f..ac64a39f2e 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -383,7 +383,7 @@ class TestHandCodedOpts(unittest.TestCase):
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
 
-    assert len(k.group_for_reduce) == 1
+    assert k.group_for_reduces == 1
     assert k.local_dims == 1
     assert k.upcasted == 1
 
@@ -612,9 +612,14 @@ class TestLinearizerOpts(unittest.TestCase):
     opts_shapes = [
       ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
       ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
-      # TODO: fix these broken transformations
-      # ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
-      # ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
+      # check to ensure local_dims are stable for full UNROLL of first_reduce
+      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
+      ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
+      # check behavior for full UNROLL on an existing GROUP
+      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
+      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
+      ([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
+      ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
     ]
     helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])
 
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 450997482c..82c242186b 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -97,7 +97,7 @@ class Kernel:
 
     # parameters for optimization
     self.applied_opts: List[Opt] = []
-    self.group_for_reduce: List[int] = []
+    self.group_for_reduces: int = 0
     self.upcasted: int = 0
     self.local_dims: int = 0
     self.local_alias: Dict[int, LocalBuffer] = {}
@@ -123,8 +123,8 @@ class Kernel:
       self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:]
 
     # parameters for optimizations
-    ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
-      self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
+    ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \
+      self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals
 
     # uncached since linearize didn't run
     ret.applied_opts_cache = None
@@ -172,7 +172,7 @@ class Kernel:
 
   @property
   def upcast_in_mid_reduce_axes(self) -> List[int]:
-    return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]
+    return [j for j in range(self.first_reduce, self.first_reduce+self.group_for_reduces) if self.full_shape[j] == self.sts[0].shape[j]]
 
   @property
   def global_dims(self) -> int: return self.first_reduce-self.local_dims
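# A minimal standalone sketch (not part of the diff) of why a plain count can
# replace the old group_for_reduce list: grouped reduce dims always sit
# contiguously at first_reduce (every shift_to in this change inserts at
# first_reduce + group_for_reduces), so the list's contents stay recoverable
# from full_shape. group_sizes is a hypothetical helper name for illustration.
from typing import List, Tuple

def group_sizes(full_shape: Tuple[int, ...], first_reduce: int, group_for_reduces: int) -> List[int]:
  # the old group_for_reduce list, derived on demand from the new counter
  return list(full_shape[first_reduce:first_reduce + group_for_reduces])

# e.g. two grouped dims of sizes 2 and 16 starting at first_reduce == 3
assert group_sizes((16, 32, 2, 2, 16, 16), 3, 2) == [2, 16]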
@@ -192,10 +192,10 @@ class Kernel:
     colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims
     # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
     colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
-    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]  # noqa: E501
-    # between first_reduce + group_for_reduce and upcasted, they are reduce (red)
-    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
+    # between first_reduce and first_reduce + group_for_reduces, they are either upcast mid reduce (white), or late upcasted (green)
+    colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + self.group_for_reduces)]  # noqa: E501
+    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
+    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + self.group_for_reduces))
     # upcasted dimensions are reduce (magenta) or normal (yellow)
     colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
     assert len(colors) == self.shape_len, "colors size mismatch"
@@ -399,7 +399,7 @@ class Kernel:
 
     assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"  # noqa: E501
     self.applied_opts.append(opt)
     if opt.axis is not None:
-      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0))  # noqa: E501
+      axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0))  # noqa: E501
     else: axis = -1
     if opt.amt is not None:
@@ -419,16 +419,18 @@ class Kernel:
       self.local_dims += 1
     elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]:  # green
       assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem"
-      assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
+      assert axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group"
       assert not self.tensor_core, "can't group with tensor cores"
-      self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + len(self.group_for_reduce))
-      self.group_for_reduce.append(amt)
+      self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+      self.group_for_reduces += 1
     elif opt.op == OptOps.UNROLL:  # purple
       assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted"
       assert amt <= 32, "don't unroll more than 32"
       # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
       #upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
       #self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
+      if self.full_shape[axis] == amt and axis == self.first_reduce: self.local_dims += 1 # first_reduce will ++, so offset loss in simplify_ones
+      if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op == OptOps.UPCAST:  # yellow
@@ -437,16 +439,16 @@ class Kernel:
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op == OptOps.UPCASTMID:  # white
-      assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"  # noqa: E501
+      assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce"  # noqa: E501
       axes = self.sts[0].unit_stride_axes()
       assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
       assert axes[0] == axis, "wrong axis"
       assert amt == 4, "don't upcast mid anything but 4"
-      self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce))
-      self.group_for_reduce.append(amt)
+      self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
+      self.group_for_reduces += 1
     elif opt.op == OptOps.NOLOCALS:
       assert self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals"
-      assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals"
+      assert self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals"
       self.dont_use_locals = True
     elif opt.op == OptOps.PADTO:
       assert not self.ast.vars(), "does not work with symbolic shape"
@@ -499,7 +501,7 @@ class Kernel:
         break
 
     # are we upcasting in mid reduce? (only for images)
-    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:  # noqa: E501
+    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:  # noqa: E501
       axes = self.sts[0].unit_stride_axes()
       assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
       if self.sts[0].shape[axes[0]]%4 == 0:
@@ -517,7 +519,7 @@ class Kernel:
           self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
 
     # no more opt if we are grouping
-    if self.group_for_reduce: return
+    if self.group_for_reduces: return
 
     # **** below this line need to be optional and benchmarked ****
 
@@ -574,7 +576,7 @@ class Kernel:
     # **** local groups ****
 
     if self.opts.has_local:
-      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
+      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
         self.apply_opt(Opt(OptOps.NOLOCALS))
       else:
         # prioritize making expand axes local
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index a6c4f0cf41..f611535255 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -165,7 +165,7 @@ class Linearizer(Kernel):
     if self.applied_opts == self.applied_opts_cache: return self
 
     # save backups
-    sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted
+    sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted
 
     # global uop cache
     self.saved_exprs: Dict[Tuple, UOp] = dict()
@@ -190,9 +190,9 @@ class Linearizer(Kernel):
       for lb in self.local_alias.values():
         self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
     # add a local buffer for multistage reduce. # TODO: use local alias
-    if self.group_for_reduce:
+    if self.group_for_reduces:
       # TODO: the strides of this can be controlled
-      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
+      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
       temp_dtype = self.get_base_dtype(get_lazyop_info(self.reduceop).dtype)
       self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
       self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
@@ -207,7 +207,7 @@ class Linearizer(Kernel):
 
     # define indexes
     global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
-    local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0)  # noqa: E501
+    local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0)  # noqa: E501
     full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
     upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
 
@@ -241,7 +241,7 @@ class Linearizer(Kernel):
     fake_reduce_idxs: List[Variable] = []
     if self.reduceop is not None:
       # define indexes
-      reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]  # noqa: E501
+      reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)]  # noqa: E501
       fake_reduce_idxs = [x*0 for x in reduce_idxs]
 
       # define accumulator
@@ -321,7 +321,7 @@ class Linearizer(Kernel):
       self.load_cache.clear()
 
      # end the local loop, do the local reduce
-      if self.group_for_reduce:
+      if self.group_for_reduces:
         fake_global_idxs = [x*0 for x in global_idxs]
         stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc)  # store accumulators
         barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
@@ -332,14 +332,14 @@ class Linearizer(Kernel):
           barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False)
 
         # create new late reduce local loops and replace local_idxs that have been used
-        end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]  # noqa: E501
+        end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)]  # noqa: E501
         local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
 
         # if any group_for_reduce items aren't reduces, upcast them here
         for j in self.upcast_in_mid_reduce_axes:
           self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
           self.upcast()
-          self.group_for_reduce.pop()
+          self.group_for_reduces -= 1
           local_idxs = local_idxs[:-1]
           end_local_idxs = end_local_idxs[:-1]
         # regenerate upcast_idxs
@@ -364,7 +364,7 @@ class Linearizer(Kernel):
 
     # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
     # been rewritten with fake end_local_idxs.
-    local_idxs = local_idxs[:self.local_dims] + [NumNode(0) for i in range(len(self.group_for_reduce))]
+    local_idxs = local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)]
 
     # load latebufs
     loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer})  # noqa: E501
@@ -387,7 +387,7 @@ class Linearizer(Kernel):
       graph_uops(self.uops)
 
     # restore backups
-    self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup
+    self.sts, self.group_for_reduces, self.upcasted = sts_backup, gfr_backup, upc_backup
 
     # set cache and return
     self.applied_opts_cache = self.applied_opts[:]
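# A minimal standalone model (not the real Kernel method; full_unroll_adjust is
# a hypothetical name) of the two bookkeeping lines the UNROLL branch gains for
# a full unroll (amt == full_shape[axis]). Fully unrolling the axis at
# first_reduce leaves a size-1 dim that first_reduce then skips past, so
# simplify_ones would decrement local_dims; bumping local_dims up front offsets
# that loss. Fully unrolling a grouped (green) axis leaves one fewer grouped
# reduce dim, so group_for_reduces is decremented.
from typing import Tuple

def full_unroll_adjust(first_reduce: int, group_for_reduces: int, local_dims: int, axis: int) -> Tuple[int, int]:
  if axis == first_reduce: local_dims += 1                            # offset the loss in simplify_ones
  if axis < first_reduce + group_for_reduces: group_for_reduces -= 1  # fully unrolling a GROUP
  return group_for_reduces, local_dims

# e.g. a full UNROLL of a GROUP dim sitting at first_reduce: both adjustments
# fire, matching the (LOCAL, GROUP, UNROLL) cases in the test above
assert full_unroll_adjust(3, 1, 1, 3) == (0, 2)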