diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py index 4aad735e7f..7696983467 100644 --- a/tinygrad/opt/kernel.py +++ b/tinygrad/opt/kernel.py @@ -445,6 +445,7 @@ class Kernel: cnt[x] = (cnt[x] + 1) if x in cnt else 0 ret.append(f"{axis_letters[x]}{cnt[x]}") return ret + def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms]) def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp: @functools.cache @@ -470,18 +471,19 @@ class Kernel: grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces) if (tc := self.tensor_core) and self.use_tensor_cores == 1: - tcd = self.first_upcast - def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast) - upcast_axes = int(math.log2(tc.elements_per_thread[buf])) - return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes)) + # resolve the tensor core reduce/upcast axes from their shape_str names (r0.., u0..) + tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))]) + base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))] + \ + [f"u{i}" for i in range(len(tc.get_upcast_axes()))])])[::-1] + tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)]) + # permute each src view with the axis order from tc.permutes_for_shape_str srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src) for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))): src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(tuple(permaxis))) - tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes()) - tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2)) + # construct the WMMA op; its arg carries the upcast and reduce axes wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, 
tc.threads, tc_upcast_axes, tc_reduce_axes) wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=( UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]), @@ -489,6 +491,7 @@ class Kernel: UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg) tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2]) + # preserve any other reduce: axes outside tc_reduce_axes stay as an ADD reduce over tc_uop return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop ret = ret.replace(arg = (op.arg[0], axes))