axis cleanups [pr] (#11142)

George Hotz, 2025-07-08 17:07:26 -07:00 (committed by GitHub)
parent 7ce9e45474
commit b11ca104e9


@@ -16,12 +16,7 @@ from tinygrad.shape.view import strides_for_shape, get_contraction
 from tinygrad.kernelize.kernelize import view_left
 class AxisType(Enum):
-  GLOBAL = auto()
-  LOCAL = auto()
-  GROUP_REDUCE = auto()
-  REDUCE = auto()
-  UPCAST = auto()
-  UNROLL = auto()
+  GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
 axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan",
                AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red",
@@ -53,8 +48,6 @@ class Kernel:
     # verify AST matches the spec
     if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
-    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
     self.vars: list[Variable] = self.ast.variables()
     # NOTE: this requires a specific order with the [::-1], this is likely a bug
     self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer][::-1]
@@ -64,6 +57,7 @@ class Kernel:
     # add the shapetrackers for each reduce
     # we use this to track which axes are reduced in each reduce
+    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
     for x in self.reduceops:
       self.sts.append(unwrap(x.st))
       self.sts.append(unwrap(x.src[0].st))
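The two hunks above are a move, not a behaviour change: the self.reduceops assignment is deleted from its earlier spot in __init__ and rebuilt here, directly above the loop that appends a shapetracker for each reduce's output view and input view, so the list lives next to its first use. Both sites filter the same self.ast.toposort(), so the resulting list should be identical.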
@@ -77,8 +71,6 @@ class Kernel:
     self.use_tensor_cores: int = 0
     self.applied_opts: list[Opt] = []
     self.dont_use_locals = False
     # finalized means you can't optimize anymore
     self.finalized: bool = False
     # group simplifies
@@ -122,23 +114,21 @@ class Kernel:
   @property
   def first_reduce(self) -> int:
-    return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
+    for i in range(self.first_upcast):
+      if self.axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): return i
+    return self.first_upcast
   @property
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
   @property
   def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
+  @property
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
   @property
   def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape
   @property
   def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
-  @property
-  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
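first_reduce is the substantive change in this hunk: instead of diffing the output shapetracker's shape against full_shape to find the first axis that collapses, it scans axis_types for the first GROUP_REDUCE or REDUCE axis and falls back to first_upcast when there is none (the duplicated output_shape lines are the same property being relocated, not a second definition). Driving this off axis_types makes the intent explicit and avoids resolve() on possibly-symbolic shapes. A self-contained sketch of the old and new computations on toy data, using plain strings and ints in place of AxisType and symbolic sints, so this is illustrative rather than tinygrad API:

# toy kernel with one global, one local, one reduce, and one upcast axis
axis_types   = ["GLOBAL", "LOCAL", "REDUCE", "UPCAST"]
output_shape = (64, 16, 1, 4)    # the reduced axis is 1 in the output shape...
full_shape   = (64, 16, 32, 4)   # ...but still 32 in the full shape
first_upcast = len(full_shape) - 1

# old definition: first axis where output and full shape disagree (the sentinels force a hit)
old = [o != f for o, f in zip(output_shape[:first_upcast] + (0,), full_shape[:first_upcast] + (1,))].index(True)
# new definition: first axis tagged as a (group) reduce, else first_upcast
new = next((i for i in range(first_upcast) if axis_types[i] in ("GROUP_REDUCE", "REDUCE")), first_upcast)
assert old == new == 2

The two agree whenever axis_types tags exactly the axes where the output and full shapes differ, which is the invariant the per-axis typing is meant to carry.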
@@ -151,25 +141,9 @@ class Kernel:
   @property
   def group_for_reduces(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GROUP_REDUCE]) if hasattr(self, 'axis_types') else 0
-  def _legacy_colors(self) -> list[str]:
-    # first non local non reduce dims are global (blue)
-    colors = ["blue"] * self.global_dims
-    # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
-    colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
-    colors += ["green"] * self.group_for_reduces
-    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
-    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
-    # upcasted dimensions are reduce (magenta) or normal (yellow)
-    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
-    assert len(colors) == self.shape_len, "colors size mismatch"
-    return colors
   def colors(self) -> list[str]:
     assert len(self.axis_types) == self.shape_len, "colors size mismatch"
-    ret = [axis_colors[x] for x in self.axis_types]
-    assert self._legacy_colors() == ret, f"legacy colors mismatch colors {self._legacy_colors()} != {ret}"
-    return ret
+    return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
   def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
     shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
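colors() now reads straight from axis_types through the axis_colors table, and the _legacy_colors reconstruction (plus the assert comparing the two) is deleted. The one conditional left is dont_use_locals: when locals are disabled, GLOBAL axes render as "BLUE" rather than "blue" (uppercase names pick the bright variant in tinygrad's colored helper). A small standalone sketch of that mapping, with plain strings for the axis types; the yellow/magenta entries for UPCAST/UNROLL are inferred from the removed _legacy_colors comments rather than shown in this hunk:

# axis kind -> display color; UPCAST/UNROLL values inferred from the deleted
# _legacy_colors comments (yellow for plain upcast, magenta for unrolled reduces)
axis_colors = {"GLOBAL": "blue", "LOCAL": "cyan", "GROUP_REDUCE": "green",
               "REDUCE": "red", "UPCAST": "yellow", "UNROLL": "magenta"}

def colors(axis_types: list[str], dont_use_locals: bool) -> list[str]:
  # same shape as the new one-liner: globals flip to bright "BLUE" only when locals are disabled
  return [axis_colors[x] if not dont_use_locals or not x == "GLOBAL" else "BLUE" for x in axis_types]

assert colors(["GLOBAL", "LOCAL", "REDUCE", "UPCAST"], dont_use_locals=False) == ["blue", "cyan", "red", "yellow"]
assert colors(["GLOBAL", "LOCAL", "REDUCE", "UPCAST"], dont_use_locals=True)  == ["BLUE", "cyan", "red", "yellow"]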
@@ -212,6 +186,8 @@ class Kernel:
   def simplify_merge_adjacent(self):
     if self.shape_len == 0: return
     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
+    # NOTE: we can't use self.first_reduce yet
+    first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
     # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
     if isinstance(self.membufs[0].dtype, ImageDType):
@@ -238,7 +214,7 @@ class Kernel:
         si, sti, last_st = s[i], st[i], ret[-1][1]
         can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
       # more can merge than this
-      mergeable = all(can_merge) and i != self.first_reduce
+      mergeable = all(can_merge) and i != first_reduce
       for j,(s,st) in enumerate(zip(shapes, strides)):
         if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
         else: rets[j].append((s[i], st[i]))
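simplify_merge_adjacent now computes first_reduce locally with the old shape-comparison expression; a likely reason, hinted at by the added NOTE and by the hasattr(self, 'axis_types') guard on group_for_reduces, is that the reworked first_reduce property needs self.axis_types, which is not yet populated when these early simplifications run in __init__. The merge rule itself is untouched: two adjacent axes fuse only if, in every shapetracker, the outer stride equals inner size times inner stride (or both are broadcast zeros), and fusion never crosses the first reduce axis. A self-contained sketch of that rule on plain (shape, strides) tuples, not tinygrad's ShapeTracker API:

def merge_adjacent(shapes: list[tuple[int, ...]], strides: list[tuple[int, ...]], first_reduce: int):
  # rets[j] is the merged (size, stride) list for buffer j, seeded with axis 0
  rets = [[(s[0], st[0])] for s, st in zip(shapes, strides)]
  for i in range(1, len(shapes[0])):
    can_merge = []
    for s, st, ret in zip(shapes, strides, rets):
      si, sti, last_st = s[i], st[i], ret[-1][1]
      # contiguous (last_st == si*sti) or both broadcast (stride 0) in this buffer
      can_merge.append((sti != 0 and last_st == si * sti) or (sti == 0 and last_st == 0))
    mergeable = all(can_merge) and i != first_reduce   # never fuse across the first reduce axis
    for j, (s, st) in enumerate(zip(shapes, strides)):
      if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
      else: rets[j].append((s[i], st[i]))
  return rets

# a (4,8) buffer stored contiguously merges into one 32-long axis...
print(merge_adjacent([(4, 8)], [(8, 1)], first_reduce=2))   # [[(32, 1)]]
# ...but not when axis 1 is the first reduce axis
print(merge_adjacent([(4, 8)], [(8, 1)], first_reduce=1))   # [[(4, 8), (8, 1)]]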