unroll axis 0 in tensor core (#11155)

* unroll is 0 in tc [pr]

* flip order of upcast/reduce in tensor core

* Revert "flip order of upcast/reduce in tensor core"

This reverts commit e564e38bcd.
This commit is contained in:
George Hotz
2025-07-09 17:28:23 -07:00
committed by GitHub
parent b7742ad9e4
commit e154a66f43
2 changed files with 3 additions and 3 deletions

View File

@@ -389,7 +389,7 @@ class Kernel:
for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
except KernelOptError: continue
# tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0
for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
self.tensor_core = tc
self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
@@ -464,7 +464,7 @@ class Kernel:
def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
  # one (axis, 2) entry per factor of two in this buffer's elements-per-thread count
  n_axes = int(math.log2(tc.elements_per_thread[buf]))
  # axes are numbered backwards from the end of the reduce+upcast range, which begins at tcd
  end_axis = tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes())
  return tuple((end_axis - (step+1), 2) for step in range(n_axes))
def get_tc_swizzle_st(shape, local_perm, reduce_perm, upcast_perm):
def get_tc_swizzle_st(shape, local_perm, upcast_perm, reduce_perm):
ru_perm = reduce_perm + upcast_perm
offset = (tcd - (wd + len(local_perm)))
permaxis = list(range(wd)) \

View File

@@ -18,7 +18,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
fwd_st = [f"l{i}" for i in range(local_axes)] + [f"r{i}" for i in range(reduce_axes)] + [f"u{i}" for i in range(upcast_axes)]
st = {s:i for i,s in enumerate(fwd_st)}
return tuple((tuple([st[x] for x in s[0]]), tuple([st[x] for x in s[2]]), tuple([st[x] for x in s[1]])) for s in self.swizzle)
return tuple((tuple([st[x] for x in s[0]]), tuple([st[x] for x in s[1]]), tuple([st[x] for x in s[2]])) for s in self.swizzle)
def get_reduce_axes(self):
  """Return ``(index, 2)`` pairs, one per factor of two in the third entry of ``self.dims``."""
  # int(log2(...)) is the number of times that dimension can be split in half
  num_axes = int(math.log2(self.dims[2]))
  return [(axis, 2) for axis in range(num_axes)]
def get_upcast_axes(self):
  """Return the entries of ``self.opts`` whose first character is ``"u"``, in their original order."""
  return list(filter(lambda entry: entry[0] == "u", self.opts))
def get_local_axes(self):
  """Return the entries of ``self.opts`` whose first character is ``"l"``, in their original order."""
  picked = []
  for entry in self.opts:
    if entry[0] == "l": picked.append(entry)
  return picked