diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index 5c51904417..3068d301e8 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -16,12 +16,7 @@ from tinygrad.shape.view import strides_for_shape, get_contraction
 from tinygrad.kernelize.kernelize import view_left
 
 class AxisType(Enum):
-  GLOBAL = auto()
-  LOCAL = auto()
-  GROUP_REDUCE = auto()
-  REDUCE = auto()
-  UPCAST = auto()
-  UNROLL = auto()
+  GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
 
 axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red",
                AxisType.UPCAST: "yellow", AxisType.UNROLL: "magenta"}
@@ -53,8 +48,6 @@ class Kernel:
 
     # verify AST matches the spec
     if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
-    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
-
     self.vars: list[Variable] = self.ast.variables()
     # NOTE: this requires a specific order with the [::-1], this is likely a bug
     self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer][::-1]
@@ -64,6 +57,7 @@ class Kernel:
 
     # add the shapetrackers for each reduce
     # we use this to track which axes are reduced in each reduce
+    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
     for x in self.reduceops:
       self.sts.append(unwrap(x.st))
       self.sts.append(unwrap(x.src[0].st))
@@ -77,8 +71,6 @@ class Kernel:
     self.use_tensor_cores: int = 0
     self.applied_opts: list[Opt] = []
     self.dont_use_locals = False
-
-    # finalized means you can't optimize anymore
     self.finalized: bool = False
 
     # group simplifies
@@ -122,23 +114,21 @@ class Kernel:
 
   @property
   def first_reduce(self) -> int:
-    return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
-
+    for i in range(self.first_upcast):
+      if self.axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): return i
+    return self.first_upcast
 
   @property
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
   @property
   def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
-
-  @property
-  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
-
   @property
   def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape
-
   @property
   def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
+  @property
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
 
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
@@ -151,25 +141,9 @@ class Kernel:
   @property
   def group_for_reduces(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GROUP_REDUCE]) if hasattr(self, 'axis_types') else 0
 
-  def _legacy_colors(self) -> list[str]:
-    # first non local non reduce dims are global (blue)
-    colors = ["blue"] * self.global_dims
-    # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
-    colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
-    colors += ["green"] * self.group_for_reduces
-    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
-    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
-    # upcasted dimensions are reduce (magenta) or normal (yellow)
-    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
-    assert len(colors) == self.shape_len, "colors size mismatch"
-    return colors
-
   def colors(self) -> list[str]:
     assert len(self.axis_types) == self.shape_len, "colors size mismatch"
-    ret = [axis_colors[x] for x in self.axis_types]
-    assert self._legacy_colors() == ret, f"legacy colors mismatch colors {self._legacy_colors()} != {ret}"
-    return ret
+    return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
 
   def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
     shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
@@ -212,6 +186,8 @@ class Kernel:
   def simplify_merge_adjacent(self):
     if self.shape_len == 0: return
     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
+    # NOTE: we can't use self.first_reduce yet
+    first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
 
     # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
     if isinstance(self.membufs[0].dtype, ImageDType):
@@ -238,7 +214,7 @@ class Kernel:
         si, sti, last_st = s[i], st[i], ret[-1][1]
         can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
       # more can merge than this
-      mergeable = all(can_merge) and i != self.first_reduce
+      mergeable = all(can_merge) and i != first_reduce
       for j,(s,st) in enumerate(zip(shapes, strides)):
        if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
        else: rets[j].append((s[i], st[i]))
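
For context on what this diff does: properties like first_reduce and colors() are now derived from an explicit per-axis AxisType tag list instead of being re-inferred from shape/stride comparisons (the old resolve(x!=y) scan survives only inside simplify_merge_adjacent, where, per the NOTE above, axis_types is not available yet). Below is a minimal standalone sketch of that idea; AxisType, axis_colors, and first_reduce mirror names in the patch, but the free function and example values are illustrative, not tinygrad's API.

    from enum import Enum, auto

    class AxisType(Enum):
      GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto()  # noqa: E702

    axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.GROUP_REDUCE: "green",
                   AxisType.REDUCE: "red", AxisType.UPCAST: "yellow", AxisType.UNROLL: "magenta"}

    def first_reduce(axis_types: list[AxisType], first_upcast: int) -> int:
      # index of the first (group-)reduce axis left of the upcasted tail, or first_upcast if none
      for i in range(first_upcast):
        if axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): return i
      return first_upcast

    # hypothetical axis layout for a kernel with one reduce and one upcasted axis
    ats = [AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE, AxisType.UPCAST]
    print(first_reduce(ats, first_upcast=3))   # -> 2
    print([axis_colors[x] for x in ats])       # -> ['blue', 'cyan', 'red', 'yellow']

With the tags as the single source of truth, the removed _legacy_colors path (counting global/local/group dims to rebuild the color list) and its cross-check assert become unnecessary, which is what lets this diff delete them.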