Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-23 22:08:08 -05:00
axis cleanups [pr] (#11142)
@@ -16,12 +16,7 @@ from tinygrad.shape.view import strides_for_shape, get_contraction
 from tinygrad.kernelize.kernelize import view_left
 
 class AxisType(Enum):
-  GLOBAL = auto()
-  LOCAL = auto()
-  GROUP_REDUCE = auto()
-  REDUCE = auto()
-  UPCAST = auto()
-  UNROLL = auto()
+  GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
 
 axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan",
                AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red",
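
Note on the hunk above: the one-line enum is behaviorally identical to the six-line form, since auto() still numbers members in definition order. A standalone sketch (the UPCAST/UNROLL color entries are assumptions inferred from the yellow/magenta comment in the deleted _legacy_colors below; the diff context cuts the dict off after REDUCE):

from enum import Enum, auto

class AxisType(Enum):
  GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto()  # noqa: E702

axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan",
               AxisType.GROUP_REDUCE: "green", AxisType.REDUCE: "red",
               AxisType.UPCAST: "yellow", AxisType.UNROLL: "magenta"}  # last two entries assumed

assert [t.value for t in AxisType] == [1, 2, 3, 4, 5, 6]  # auto() numbers members in definition order
print(axis_colors[AxisType.REDUCE])  # red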
@@ -53,8 +48,6 @@ class Kernel:
     # verify AST matches the spec
     if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
 
-    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
-
     self.vars: list[Variable] = self.ast.variables()
     # NOTE: this requires a specific order with the [::-1], this is likely a bug
     self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer][::-1]
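
Note: the surviving [::-1] (and the NOTE calling it a likely bug) is order-sensitive: toposort() yields sources before the ops that consume them, so the reverse puts later buffers first. A toy sketch with strings standing in for UOps:

# hypothetical mini-graph in topological order: sources first
toposort = ["BUF:a", "BUF:b", "LOAD", "LOAD", "ADD", "STORE"]
bufs = [x for x in toposort if x.startswith("BUF")][::-1]
print(bufs)  # ['BUF:b', 'BUF:a'] -- downstream code depends on this reversed order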
@@ -64,6 +57,7 @@ class Kernel:
 
     # add the shapetrackers for each reduce
     # we use this to track which axes are reduced in each reduce
+    self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
     for x in self.reduceops:
       self.sts.append(unwrap(x.st))
       self.sts.append(unwrap(x.src[0].st))
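
Note: two shapetrackers go in per reduce, the REDUCE_AXIS op's own (reduced axes collapsed to 1) and its source's (axes at full size). Since full_shape below is sts[-1].shape, the last-appended source st is what keeps the reduce extents visible. A hedged toy sketch with bare shape tuples standing in for ShapeTrackers:

# plain tuples stand in for ShapeTrackers; only .shape matters here
reduce_out_st = (4, 1)    # unwrap(x.st): shape after reducing axis 1 to size 1
reduce_src_st = (4, 16)   # unwrap(x.src[0].st): shape before the reduce
sts = [(4, 1)]            # sts[0] is the kernel output
sts += [reduce_out_st, reduce_src_st]
output_shape, full_shape = sts[0], sts[-1]
assert output_shape == (4, 1) and full_shape == (4, 16)  # they differ exactly on the reduce axis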
@@ -77,8 +71,6 @@ class Kernel:
     self.use_tensor_cores: int = 0
     self.applied_opts: list[Opt] = []
     self.dont_use_locals = False
 
-    # finalized means you can't optimize anymore
-    self.finalized: bool = False
 
     # group simplifies
@@ -122,23 +114,21 @@ class Kernel:
 
   @property
   def first_reduce(self) -> int:
-    return [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
+    for i in range(self.first_upcast):
+      if self.axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): return i
+    return self.first_upcast
 
   @property
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
   @property
   def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
 
+  @property
+  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
   @property
   def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape
 
   @property
   def full_unupcasted_shape(self) -> tuple[sint, ...]: return self.full_shape[:self.first_upcast]
 
-  @property
-  def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
-
   @property
   def shape_len(self) -> int: return len(self.sts[0].shape)
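
Note on first_reduce: the old body inferred the reduce boundary by comparing output_shape against full_shape; the new body reads it straight from axis_types, which this PR makes the source of truth. A self-contained sketch (toy lists, no tinygrad imports; resolve() is only needed for symbolic shapes, so plain != suffices here) checking that the two formulations agree:

from enum import Enum, auto

class AxisType(Enum):
  GLOBAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto()  # noqa: E702

# toy kernel: 2 global axes, 1 reduce axis, 1 upcasted axis
axis_types   = [AxisType.GLOBAL, AxisType.GLOBAL, AxisType.REDUCE, AxisType.UPCAST]
output_shape = (8, 8, 1, 4)    # sts[0].shape: the reduce axis is already 1
full_shape   = (8, 8, 32, 4)   # sts[-1].shape: the reduce axis is still 32
first_upcast = 3               # shape_len - upcasted

# old way: first axis (below first_upcast) where output and full shapes diverge,
# with (0,)/(1,) sentinels so index(True) always finds a hit
old = [x != y for x, y in zip(output_shape[:first_upcast]+(0,), full_shape[:first_upcast]+(1,))].index(True)

# new way: first GROUP_REDUCE/REDUCE entry in axis_types, else first_upcast
new = next((i for i in range(first_upcast) if axis_types[i] in (AxisType.GROUP_REDUCE, AxisType.REDUCE)), first_upcast)
assert old == new == 2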
@@ -151,25 +141,9 @@ class Kernel:
   @property
   def group_for_reduces(self) -> int: return sum([1 for x in self.axis_types if x == AxisType.GROUP_REDUCE]) if hasattr(self, 'axis_types') else 0
 
-  def _legacy_colors(self) -> list[str]:
-    # first non local non reduce dims are global (blue)
-    colors = ["blue"] * self.global_dims
-    # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
-    colors += ["cyan"] * self.local_dims
-    # between first_reduce and first_reduce + group_for_reduces, they are late upcasted (green)
-    colors += ["green"] * self.group_for_reduces
-    # between first_reduce + group_for_reduces and upcasted, they are reduce (red)
-    colors += ["red"] * (self.first_upcast - (self.first_reduce + self.group_for_reduces))
-    # upcasted dimensions are reduce (magenta) or normal (yellow)
-    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.first_upcast, self.shape_len)]
-    assert len(colors) == self.shape_len, "colors size mismatch"
-    return colors
-
   def colors(self) -> list[str]:
     assert len(self.axis_types) == self.shape_len, "colors size mismatch"
-    ret = [axis_colors[x] for x in self.axis_types]
-    assert self._legacy_colors() == ret, f"legacy colors mismatch colors {self._legacy_colors()} != {ret}"
-    return ret
+    return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
 
   def colored_shape(self, pad:Optional[int]=None, dense=False) -> str:
     shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
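
Note: colors() now maps axis_types straight through the axis_colors table instead of rebuilding the palette positionally, which is what lets _legacy_colors and its cross-checking assert be deleted; the one quirk kept is uppercasing GLOBAL to "BLUE" under dont_use_locals. A standalone sketch of the new logic plus the group_for_reduces count (enum and table re-declared locally; the UPCAST/UNROLL colors are assumptions inferred from the yellow/magenta comment in the deleted code):

from enum import Enum, auto

class AxisType(Enum):
  GLOBAL = auto(); LOCAL = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto()  # noqa: E702

axis_colors = {AxisType.GLOBAL: "blue", AxisType.LOCAL: "cyan", AxisType.GROUP_REDUCE: "green",
               AxisType.REDUCE: "red", AxisType.UPCAST: "yellow", AxisType.UNROLL: "magenta"}  # UPCAST/UNROLL assumed

def colors(axis_types, dont_use_locals=False):
  return [axis_colors[x] if not dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in axis_types]

print(colors([AxisType.GLOBAL, AxisType.REDUCE]))                        # ['blue', 'red']
print(colors([AxisType.GLOBAL, AxisType.REDUCE], dont_use_locals=True))  # ['BLUE', 'red']

# group_for_reduces is just a count over the same list:
assert sum(1 for x in [AxisType.GLOBAL, AxisType.GROUP_REDUCE] if x == AxisType.GROUP_REDUCE) == 1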
@@ -212,6 +186,8 @@ class Kernel:
   def simplify_merge_adjacent(self):
     if self.shape_len == 0: return
     shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
+    # NOTE: we can't use self.first_reduce yet
+    first_reduce = [resolve(x!=y) for x,y in zip(self.sts[0].shape[:self.first_upcast]+(0,), self.full_shape[:self.first_upcast]+(1,))].index(True)
 
     # if it's an image, insert fake strides such that this fusion doesn't happen across image axes
     if isinstance(self.membufs[0].dtype, ImageDType):
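
Note: the inlined computation (with the comment explaining why self.first_reduce can't be used while shapes are being rewritten) relies on a sentinel trick: appending (0,) and (1,) guarantees one unequal pair, so .index(True) can't raise ValueError when there are no reduce axes and instead returns first_upcast. A minimal sketch of just that trick, with plain != in place of resolve():

# toy reduce-free kernel: output and full shapes agree on every axis
output_shape, full_shape = (8, 8), (8, 8)
diverge = [x != y for x, y in zip(output_shape + (0,), full_shape + (1,))]
# the sentinels force one final True, landing just past the real axes
assert diverge == [False, False, True] and diverge.index(True) == 2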
@@ -238,7 +214,7 @@ class Kernel:
         si, sti, last_st = s[i], st[i], ret[-1][1]
         can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
       # more can merge than this
-      mergeable = all(can_merge) and i != self.first_reduce
+      mergeable = all(can_merge) and i != first_reduce
       for j,(s,st) in enumerate(zip(shapes, strides)):
         if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
         else: rets[j].append((s[i], st[i]))
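
Note on the mergeability test above: adjacent axes fuse when the outer stride equals inner_size*inner_stride (plain row-major contiguity) or when both strides are 0 (broadcast), and never across the reduce boundary, which is why the freshly computed local first_reduce replaces self.first_reduce here. A sketch of the contiguity test on concrete (size, stride) pairs:

def can_merge_pair(outer, inner):
  # outer/inner are adjacent (size, stride) axes, outermost first
  (_, outer_st), (inner_sz, inner_st) = outer, inner
  return inner_st is not None and ((inner_st != 0 and outer_st == inner_sz*inner_st)
                                   or (inner_st == 0 and outer_st == 0))

assert can_merge_pair((4, 16), (16, 1))      # contiguous (4,16) view fuses into one axis of 64
assert not can_merge_pair((4, 20), (16, 1))  # padded rows (stride 20 != 16*1) can't fuse
assert can_merge_pair((4, 0), (16, 0))       # two broadcast axes fuse as well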