Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 13:58:00 -05:00)
use shape str for tensor cores upcast/reduce [pr] (#11168)
* use shape str for tensor cores upcast/reduce [pr]
* reduce axis count isn't fixed
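The core of the change is the new Kernel.shape_str_to_axis helper (added in the first hunk below): the tensor-core reduce and upcast dimensions are now found by the names that shape_str() gives each axis (e.g. "r0", "u0"), instead of by offset arithmetic from first_upcast. A minimal standalone sketch of that lookup, assuming a made-up axis naming (the concrete letters and shape below are illustrative, not tinygrad's exact output):

# Standalone sketch, not tinygrad's Kernel class. Assume shape_str() has named the
# kernel's axes by type and index; the concrete list here is made up for illustration.
shape_str = ["g0", "g1", "l0", "r0", "r1", "u0", "u1"]

def shape_str_to_axis(nms: list[str]) -> tuple[int, ...]:
  # find each named axis by its position instead of computing a fixed offset
  return tuple(shape_str.index(x) for x in nms)

print(shape_str_to_axis(["r0", "r1"]))  # (3, 4) -> tensor core reduce axes
print(shape_str_to_axis(["u0", "u1"]))  # (5, 6) -> tensor core upcast axes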
@@ -445,6 +445,7 @@ class Kernel:
      cnt[x] = (cnt[x] + 1) if x in cnt else 0
      ret.append(f"{axis_letters[x]}{cnt[x]}")
    return ret
  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])

  def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
    @functools.cache
@@ -470,18 +471,19 @@ class Kernel:
      grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)

      if (tc := self.tensor_core) and self.use_tensor_cores == 1:
        tcd = self.first_upcast
        def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
          upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
          return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
        # get reduce/upcast axes for the tensor cores
        tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
        base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))] + \
                                                                        [f"u{i}" for i in range(len(tc.get_upcast_axes()))])])[::-1]
        tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])

        # permute the srcs
        srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
        for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))):
          src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
          srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(tuple(permaxis)))

        tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
        tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
        # construct the op
        wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
        wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
          UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
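In the hunk above, the get_upcast_axes helper and the two tcd-offset assignments near the bottom are the pre-change code that this commit removes; the shape_str_to_axis based lines replace them. The new computation builds one (axis, 2) pair per tensor-core reduce/upcast axis, reversed so the innermost axes come first, and each of the three WMMA operands keeps only log2(elements_per_thread) of those pairs. A standalone sketch with made-up shape and elements_per_thread values (the real values come from the selected tensor core; these are only for illustration):

import math

# made-up example: two tensor-core reduce axes at positions 3,4 and two upcast axes at 5,6
shape_str = ["g0", "g1", "l0", "r0", "r1", "u0", "u1"]
elements_per_thread = (8, 4, 4)  # hypothetical per-operand element counts

def shape_str_to_axis(nms): return tuple(shape_str.index(x) for x in nms)

# all candidate (axis, 2) pairs over the reduce + upcast axes, in reverse shape order
base_upcast_axes = tuple((s, 2) for s in shape_str_to_axis(["r0", "r1", "u0", "u1"]))[::-1]
# each operand upcasts along just enough axes to cover its elements_per_thread
tc_upcast_axes = tuple(base_upcast_axes[:int(math.log2(elements_per_thread[i]))] for i in range(3))

print(base_upcast_axes)  # ((6, 2), (5, 2), (4, 2), (3, 2))
print(tc_upcast_axes)    # (((6, 2), (5, 2), (4, 2)), ((6, 2), (5, 2)), ((6, 2), (5, 2)))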
@@ -489,6 +491,7 @@ class Kernel:
          UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
        tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])

        # preserve any other reduce
        return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop

      ret = ret.replace(arg = (op.arg[0], axes))
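The return in the last hunk keeps any reduce axes the tensor core does not consume: if axes contains more than tc_reduce_axes, the WMMA result is wrapped in an ADD reduce over the leftovers, otherwise it is returned directly. A standalone sketch with made-up axis numbers:

# hypothetical reduce axes of the original op vs. axes consumed by the tensor core
axes = (3, 4, 7)
tc_reduce_axes = (3, 4)

# anything the WMMA doesn't reduce stays behind; an empty tuple would mean the WMMA
# result can be returned directly without a wrapping reduce
new_axes = tuple(i for i in axes if i not in tc_reduce_axes)
print(new_axes)  # (7,)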