From e154a66f4388227d6438b61fee43fea0e75afc7f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 9 Jul 2025 17:28:23 -0700 Subject: [PATCH] unroll axis 0 in tensor core (#11155) * unroll is 0 in tc [pr] * flip order of upcast/reduce in tensor core * Revert "flip order of upcast/reduce in tensor core" This reverts commit e564e38bcd23874f30c4d57b5d1e07a180ac5759. --- tinygrad/opt/kernel.py | 4 ++-- tinygrad/opt/tc.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py index b684361c91..ceebd20686 100644 --- a/tinygrad/opt/kernel.py +++ b/tinygrad/opt/kernel.py @@ -389,7 +389,7 @@ class Kernel: for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail except KernelOptError: continue # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M) - for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False) + for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0 for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False) self.tensor_core = tc self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA @@ -464,7 +464,7 @@ class Kernel: def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast) upcast_axes = int(math.log2(tc.elements_per_thread[buf])) return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes)) - def get_tc_swizzle_st(shape, local_perm, reduce_perm, upcast_perm): + def get_tc_swizzle_st(shape, local_perm, upcast_perm, reduce_perm): ru_perm = reduce_perm + upcast_perm offset = (tcd - (wd + len(local_perm))) permaxis = list(range(wd)) \ diff --git a/tinygrad/opt/tc.py b/tinygrad/opt/tc.py index 3f2910f3f1..33fdb06fe7 100644 --- a/tinygrad/opt/tc.py +++ b/tinygrad/opt/tc.py @@ -18,7 +18,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes()) fwd_st = [f"l{i}" for i in range(local_axes)] + [f"r{i}" for i in range(reduce_axes)] + [f"u{i}" for i in range(upcast_axes)] st = {s:i for i,s in enumerate(fwd_st)} - return tuple((tuple([st[x] for x in s[0]]), tuple([st[x] for x in s[2]]), tuple([st[x] for x in s[1]])) for s in self.swizzle) + return tuple((tuple([st[x] for x in s[0]]), tuple([st[x] for x in s[1]]), tuple([st[x] for x in s[2]])) for s in self.swizzle) def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))] def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"] def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]