use shape str for tensor cores upcast/reduce [pr] (#11168)

* use shape str for tensor cores upcast/reduce [pr]

* reduce axis count isn't fixed
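For context on the naming this relies on: shape_str() labels every axis of the kernel's full shape with a letter for its AxisType plus a per-letter counter, so reduce axes come out as r0, r1, ... and upcast axes as u0, u1, ...; the new shape_str_to_axis inverts that naming back to axis positions. A minimal standalone sketch of the idea (a toy: letters are passed in directly here instead of being derived from self.axis_types via axis_letters):

def shape_str(letters: list[str]) -> list[str]:
  ret: list[str] = []
  cnt: dict[str, int] = {}
  for x in letters:
    cnt[x] = (cnt[x] + 1) if x in cnt else 0  # 0-based counter per letter
    ret.append(f"{x}{cnt[x]}")
  return ret

def shape_str_to_axis(names: list[str], nms: list[str]) -> tuple[int, ...]:
  # same lookup the new Kernel.shape_str_to_axis does: name -> axis index
  return tuple(names.index(x) for x in nms)

names = shape_str(["g", "g", "r", "r", "u", "u"])  # two globals, two reduces, two upcasts
assert names == ["g0", "g1", "r0", "r1", "u0", "u1"]
assert shape_str_to_axis(names, ["r0", "r1"]) == (2, 3)
assert shape_str_to_axis(names, ["u1"]) == (5,)

Looking axes up by name, rather than by offset arithmetic from first_upcast, keeps the tensor core code correct however many reduce axes a kernel ends up with, which is the point of the second commit.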
Author: George Hotz
Committed by: GitHub
Date: 2025-07-10 13:10:58 -07:00
Parent: cc6ed30f4f
Commit: 05613c8cac

@@ -445,6 +445,7 @@ class Kernel:
       cnt[x] = (cnt[x] + 1) if x in cnt else 0
       ret.append(f"{axis_letters[x]}{cnt[x]}")
     return ret
+  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])
 
   def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
     @functools.cache
@@ -470,18 +471,19 @@
         grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
         if (tc := self.tensor_core) and self.use_tensor_cores == 1:
-          tcd = self.first_upcast
-          def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
-            upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
-            return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
+          # get reduce/upcast axes for the tensor cores
+          tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
+          base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))] + \
+                                                                          [f"u{i}" for i in range(len(tc.get_upcast_axes()))])])[::-1]
+          tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
           # permute the srcs
           srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
           for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))):
             src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
             srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(tuple(permaxis)))
-          tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
-          tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
           # construct the op
           wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
           wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
             UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
@@ -489,6 +491,7 @@
             UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
           tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
+          # preserve any other reduce
+          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
         ret = ret.replace(arg = (op.arg[0], axes))
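To make the second hunk concrete, here is a standalone rerun of its axis selection with made-up numbers (the axis layout and elements_per_thread below are illustrative only, not taken from any real TensorCore definition):

import math

# Hypothetical kernel: a tensor core with two reduce axes and two upcast
# axes, and 4 elements per thread for each of the three operands.
names = ["g0", "g1", "r0", "r1", "u0", "u1"]   # stand-in for self.shape_str()
reduce_names = ["r0", "r1"]                    # one name per tc.get_reduce_axes()
upcast_names = ["u0", "u1"]                    # one name per tc.get_upcast_axes()
elements_per_thread = (4, 4, 4)                # stand-in for tc.elements_per_thread

def shape_str_to_axis(nms: list[str]) -> tuple[int, ...]:
  return tuple(names.index(x) for x in nms)

tc_reduce_axes = shape_str_to_axis(reduce_names)   # (2, 3)
# all (reduce + upcast) axes as (axis, 2) unroll pairs, innermost first
base_upcast_axes = tuple((s, 2) for s in shape_str_to_axis(reduce_names + upcast_names))[::-1]
# each src takes just enough pairs to cover its elements per thread (2**n)
tc_upcast_axes = tuple(base_upcast_axes[:int(math.log2(e))] for e in elements_per_thread)
print(base_upcast_axes)  # ((5, 2), (4, 2), (3, 2), (2, 2))
print(tc_upcast_axes)    # (((5, 2), (4, 2)), ((5, 2), (4, 2)), ((5, 2), (4, 2)))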
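And the guard at the end of the last hunk, with the walrus pulled out and assumed axis numbers, shows how a tensor core that covers only some of the reduced axes leaves the rest to an ordinary ADD reduce:

# Hypothetical axis numbers: the REDUCE_AXIS op reduces axes 2, 3 and 4,
# but the tensor core itself only covers axes 2 and 3.
axes = (2, 3, 4)
tc_reduce_axes = (2, 3)
new_axes = tuple(i for i in axes if i not in tc_reduce_axes)
print(new_axes)  # (4,) -> wrap tc_uop in a REDUCE_AXIS doing ADD over axis 4
# if new_axes were empty, tc_uop would be returned as-is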