diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index d23a775860..eb62cd2968 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -6,11 +6,11 @@ from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callab
 from enum import Enum, auto
 
 from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
-  graph_rewrite, track_rewrites, UPat
+  graph_rewrite, track_rewrites
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
 from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import strides_for_shape
@@ -607,83 +607,97 @@ class Kernel:
     return name + colored(num, 'BLACK')
 
   def get_optimized_ast(self) -> UOp:
+    # set the shapetrackers to the optimized ones, fixup reduceop
+    # transformed to the final UOp
     @functools.lru_cache(None)
-    def fixup_ast(op:UOp) -> UOp:
-      ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
-      if op.op in GroupOp.Buffer and op in self.bufs:
-        st_uop = self.sts[self.bufs.index(op)].to_uop()
-        return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
-      if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
+    def fixup_ast(op:UOp, apply_to_st=None) -> UOp:
+      arg = op.arg
+      if op.op in GroupOp.Buffer:
+        # for locals, we use the ShapeTracker that's in the srcs
+        st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
+        st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
+        if op.op is Ops.VALID: return op.replace(src=(st_uop,))
+        if op.op is Ops.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
+        return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
       if op.op is Ops.REDUCE_AXIS:
-        reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
+        reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
+        alu_op: Ops = op.arg[0]
+        axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
+                     if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
+        if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
+          rsrc = op.src[0]
+          if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
+          assert rsrc.op is Ops.MUL
 
-        def reduced_axes(start, stop):
-          return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
-        axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
-        grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
+          def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
+            wd, tcd = self.global_dims, self.first_upcast
+            assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
+            assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
+            new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
+            permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + list(range(wd+len(warp_dims), tcd)) + \
+                       [y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
+            return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
 
-        if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
-          def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
-            wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
-            tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
+          warp_dims = tuple(sz for _, sz in tc.threads)
+          tcd_dims = tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
+          fix_st1 = functools.partial(fix_st, warp_dims, tcd_dims, tc.expanded_shape, *tc.st1_pattern) if tc.st1_pattern else None
+          fix_st2 = functools.partial(fix_st, warp_dims, tcd_dims, tc.expanded_shape, *tc.st2_pattern) if tc.st2_pattern else None
 
-            assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
-            assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
-            assert tc.expanded_shape is not None
-
-            new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
-            permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
-              [y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
-            return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
-
-          srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
-          for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
-            if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))
-
-            if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
-              local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
-              st = store_st = ShapeTracker.from_shape(local_shape)
-              local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
-              if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
-              local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
-              srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
-
-          tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
-          if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/EXPAND to get the vectorization right
-            upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
-            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
-            wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
+          assert apply_to_st is None, "double tensor core? not supported"
+          wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(t[1] for t in tc.threads),
+                      tuple(tuple((self.first_upcast+ax,sz) for ax,sz in up) for up in tc.upcast_axes), tuple(self.first_upcast+ax for ax,_ in tc.reduce_axes))
+          if self.use_tensor_cores >= 2:
+            if self.use_tensor_cores == 3:
+              # TC=3, emulate the warp addressing with locals
+              ex_shape = tuple(1 if i < self.global_dims or (i >= self.first_reduce and i < self.first_upcast) else s \
+                               for i,s in enumerate(self.full_shape))
+              srcs = []
+              for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
+                st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is Ops.LOAD]
+                local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
+                st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
+                membuf = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
+                local_store = fixup_ast(UOp(Ops.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
+                srcs.append(UOp(Ops.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
+            else:
+              # for TC=2, we can't do the shapetracker fixup
+              srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
+            # MUL/SUM instead of WMMA
+            ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(Ops.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
+          else:
+            # real WMMA, use CONTRACT/EXPAND to get the vectorization right
+            wmma_upcast_axes = wmma_arg[-2]
+            wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
             wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
-              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
-              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
+              UOp(Ops.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
+              UOp(Ops.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
               UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
-            tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
-
-          else: # for TC=3 MUL/SUM instead of WMMA
-            tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
-
-          new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
-          return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
-
-        ret = ret.replace(arg = (op.arg[0], axes))
-        if self.group_for_reduces and grouped_axes:
+            ret = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
+          new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in [ax for ax, _ in tc.reduce_axes])
+          return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
+        if self.group_for_reduces:
+          start = UOp(Ops.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
+          second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
+                              if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
+          # NOTE: if there's a grouped reduce, but no reduce axes for this reduce, we can skip it
+          if len(second_axis) == 0: return start
           local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
             tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
               for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
-          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
-          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
+          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
+          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
           if op is self.reduceops[-1]: return grouped_reduce
-          st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
+          st_uop = ShapeTracker.from_shape(tuple([1 if i in second_axis else a for i,a in enumerate(local_shape)])).to_uop()
           return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
-
-      return ret
-
-    return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
-      (UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
-      (UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
+        arg = (alu_op, axis)
+      elif op.op is Ops.SINK:
+        arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
+      return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
+    # NOTE: rewrite with an empty PatternMatcher to dedup UOps
+    return graph_rewrite(fixup_ast(self.ast), PatternMatcher([]))
 
   # **** this is the lowerer ****
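
A note on fix_st, the trickiest piece of the new tensor-core path: it expands the tensor-core dims of a view to tc.expanded_shape, permutes warp and tcd axes according to st1_pattern/st2_pattern, then collapses back to the original shape. Below is a standalone sketch of that reshape/permute/reshape movement on a numpy array; the dims and patterns are made up for illustration and are not taken from any real TensorCore definition (and the real code permutes ShapeTracker views, not data):

import numpy as np

wd, tcd = 1, 2                         # hypothetical: warp dims start at axis 1, tensor-core dims at axis 2
warp_dims, tcd_expand = (2,), (2, 2)   # one warp dim of size 2; the tcd dim of 4 is viewed as 2x2
pattern_1 = ((1, 0),)                  # the warp axis is drawn from the first expanded tcd dim
pattern_2 = ((0, 0), (1, 1))           # the upcast axes: the old warp dim and the second expanded dim

x = np.arange(3 * 2 * 4).reshape(3, 2, 4)                  # (global, warp, tcd)
new_shape = x.shape[:tcd] + tcd_expand + x.shape[tcd+1:]   # expand the tcd, as in fix_st
permaxis = list(range(wd)) + [y + (wd if p == 0 else tcd) for p, y in pattern_1] + \
  list(range(wd+len(warp_dims), tcd)) + [y + (wd if p == 0 else tcd) for p, y in pattern_2] + \
  list(range(tcd+len(tcd_expand), len(new_shape)))
out = x.reshape(new_shape).transpose(permaxis).reshape(x.shape)
print(permaxis, out.shape)  # [0, 2, 1, 3] (3, 2, 4): same shape, warp/upcast layout permuted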
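One subtlety of the new fixup_ast(op, apply_to_st) signature: functools.lru_cache requires every argument to be hashable, and the fix_st1/fix_st2 values passed as apply_to_st are functools.partial objects, which hash (and compare) by identity. That is sufficient here because each partial is built exactly once per get_optimized_ast call. A quick generic demonstration, with hypothetical names:

import functools

@functools.lru_cache(None)
def memoized(op, apply_to_st=None):
  print("computed:", op, apply_to_st)
  return (op, apply_to_st)

fix = functools.partial(lambda st, scale: st * scale, scale=3)  # stand-in for fix_st1
memoized("node", fix)  # prints "computed: ..."
memoized("node", fix)  # cache hit: same partial object, so same cache key
memoized("node", functools.partial(lambda st, scale: st * scale, scale=3))  # miss: equal-looking partial, different identity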
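Finally, the two UPat view-pushing rules are gone and the closing graph_rewrite takes an empty PatternMatcher; per the NOTE it survives only to dedup UOps. A minimal sketch of why a bottom-up rewrite with no rules still merges nodes, assuming (as the NOTE implies) the rewriter hash-conses structurally identical nodes; this is generic Python, not tinygrad's graph_rewrite:

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()

def dedup(n: Node, cache: dict) -> Node:
  # rebuild bottom-up through a cache keyed on structure, so equal subtrees collapse
  new = Node(n.op, tuple(dedup(s, cache) for s in n.src))
  return cache.setdefault(new, new)

a1, a2 = Node("LOAD"), Node("LOAD")   # structurally equal, distinct objects
root = dedup(Node("ADD", (a1, a2)), {})
assert root.src[0] is root.src[1]     # the two LOADs now share one node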