diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index cd8e8196ac..d6a81371fc 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -480,7 +480,7 @@ class Kernel:
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
     elif opt.op is OptOps.SWAP:
-      check(axis < amt and amt < self.global_dims, "swap is only for globals with axis < amt")
+      check(axis < amt and amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
       permute = list(range(self.shape_len))
       permute[axis], permute[amt] = permute[amt], permute[axis]
       self.reshape_and_permute(None, tuple(permute))
@@ -511,14 +511,15 @@ class Kernel:
     if self.simplify_ones() and self.tensor_core_opts:
       self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
 
-  def required_optimizations(self):
+  def required_optimizations(self) -> Kernel:
     if self.bufs[0].dtype.__class__ is ImageDType:
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
       assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
       if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
         self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+    return self
 
-  def hand_coded_optimizations(self):
+  def hand_coded_optimizations(self) -> Kernel:
     self.required_optimizations()
 
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
@@ -539,7 +540,7 @@ class Kernel:
             if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
             # SWAP global
             if self.global_dims >= 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
-            return
+            return self
 
     if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
       # are we grouping? (requires local shape support)
@@ -574,7 +575,7 @@ class Kernel:
     if self.group_for_reduces:
       # SWAP global
       if self.global_dims >= 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
-      return
+      return self
 
     # **** below this line need to be optional and benchmarked ****
 
@@ -651,6 +652,8 @@ class Kernel:
     # SWAP global
     if self.global_dims >= 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
 
+    return self
+
   # **** kernel outputs ****
 
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 2e1287fc74..2c84883e4c 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -16,22 +16,18 @@ logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS
 
 def get_kernel(renderer:Renderer, ast:LazyOp) -> Kernel:
   if DEBUG >= 5: print(ast)
-  k = Kernel(ast, opts=renderer)
-  k.required_optimizations()
+  k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
     if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
     if BEAM >= 1:
       from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
-      kb, k_opt = Kernel(ast, opts=renderer), k
-      kb.required_optimizations()
+      kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
       rawbufs = bufs_from_lin(kb, allocate=False)
       k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
       if beam_compare:=getenv("BEAM_COMPARE", 1):
         # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
         lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
-        if used_tensor_cores:
-          lins.append(("hc", Kernel(ast, opts=renderer)))
-          lins[-1][1].hand_coded_optimizations()
+        if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
         timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
         k = timed[0][1]
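
Note: the kernel.py change makes required_optimizations() and hand_coded_optimizations() return the Kernel instance instead of None, which is what lets the realize.py call sites above collapse construction and setup into one expression. A minimal sketch of the before/after pattern, assuming hypothetical ast and renderer values already in scope (only the class and method names come from the diff):

    from tinygrad.codegen.kernel import Kernel

    # before: the methods returned None, so setup took two statements
    k = Kernel(ast, opts=renderer)   # ast/renderer are placeholders here
    k.required_optimizations()

    # after: the methods return self, so construction and setup chain
    k = Kernel(ast, opts=renderer).required_optimizations()
    kb = Kernel(ast, opts=renderer).hand_coded_optimizations()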