Kernel.required_optimizations and Kernel.hand_coded_optimizations return self (#5576)
[run_process_replay]
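For context: this change adopts the method-chaining (fluent) style, where an in-place optimization pass returns the object it just mutated, so construction and optimization collapse into a single expression. A minimal sketch of the pattern follows; Builder, step_a, and step_b are illustrative names, not tinygrad API:

from __future__ import annotations  # lets `-> Builder` name the class from inside its own body

class Builder:
  def __init__(self) -> None:
    self.steps: list[str] = []

  def step_a(self) -> Builder:
    self.steps.append("a")
    return self  # handing back the same instance is what enables chaining

  def step_b(self) -> Builder:
    self.steps.append("b")
    return self

b = Builder().step_a().step_b()  # one expression, like Kernel(ast, opts=renderer).required_optimizations()
assert b.steps == ["a", "b"]     # both calls mutated the one instance

The same annotation detail applies to the diff below: `-> Kernel` names the class from inside its own body, which needs postponed annotation evaluation (a `from __future__ import annotations` at module top) or a string annotation.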
@@ -480,7 +480,7 @@ class Kernel:
       check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
       self.dont_use_locals = True
     elif opt.op is OptOps.SWAP:
-      check(axis < amt and amt < self.global_dims, "swap is only for globals with axis < amt")
+      check(axis < amt and amt < self.global_dims, f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {self.global_dims=}")
       permute = list(range(self.shape_len))
       permute[axis], permute[amt] = permute[amt], permute[axis]
       self.reshape_and_permute(None, tuple(permute))
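The reworded check above leans on Python 3.8's self-documenting f-string fields: `{amt=}` expands to the text `amt=` followed by the value's repr, so a failing check now reports the offending numbers. A quick demonstration with invented values:

amt, axis, global_dims = 1, 2, 3  # made-up values for illustration
print(f"swap is only for globals with axis < amt, getting {amt=}, {axis=}, {global_dims=}")
# prints: swap is only for globals with axis < amt, getting amt=1, axis=2, global_dims=3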
@@ -511,14 +511,15 @@ class Kernel:
     if self.simplify_ones() and self.tensor_core_opts:
       self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
 
-  def required_optimizations(self):
+  def required_optimizations(self) -> Kernel:
     if self.bufs[0].dtype.__class__ is ImageDType:
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
       assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
       if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
         self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+    return self
 
-  def hand_coded_optimizations(self):
+  def hand_coded_optimizations(self) -> Kernel:
     self.required_optimizations()
 
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
@@ -539,7 +540,7 @@ class Kernel:
         if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
         # SWAP global
         if self.global_dims >= 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
-        return
+        return self
 
     if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
       # are we grouping? (requires local shape support)
@@ -574,7 +575,7 @@ class Kernel:
       if self.group_for_reduces:
         # SWAP global
         if self.global_dims >= 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
-        return
+        return self
 
     # **** below this line need to be optional and benchmarked ****
@@ -651,6 +652,8 @@ class Kernel:
     # SWAP global
     if self.global_dims >= 3: self.apply_opt(Opt(OptOps.SWAP, 0, self.global_dims-1))
 
+    return self
+
   # **** kernel outputs ****
 
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
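Every early `return` above had to become `return self` for the chaining to be safe: a call site like Kernel(...).hand_coded_optimizations() consumes the return value, and a bare `return` hands back None, so the next attribute access dies. A self-contained illustration (hypothetical class K, not tinygrad code):

class K:
  def opt_bad(self):
    return  # implicitly returns None; breaks any caller that chains on the result

  def opt_good(self):
    return self

K().opt_good().opt_good()  # fine: each call hands the instance back
try:
  K().opt_bad().opt_good()
except AttributeError as e:
  print(e)  # 'NoneType' object has no attribute 'opt_good'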
@@ -16,22 +16,18 @@ logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS
 def get_kernel(renderer:Renderer, ast:LazyOp) -> Kernel:
   if DEBUG >= 5:
     print(ast)
-  k = Kernel(ast, opts=renderer)
-  k.required_optimizations()
+  k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
     if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
     if BEAM >= 1:
       from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
-      kb, k_opt = Kernel(ast, opts=renderer), k
-      kb.required_optimizations()
+      kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
       rawbufs = bufs_from_lin(kb, allocate=False)
       k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
       if beam_compare:=getenv("BEAM_COMPARE", 1):
         # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
         lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
-        if used_tensor_cores:
-          lins.append(("hc", Kernel(ast, opts=renderer)))
-          lins[-1][1].hand_coded_optimizations()
+        if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
         timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
         k = timed[0][1]
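The BEAM_COMPARE branch above selects a winner by timing every candidate kernel and sorting on the measurement: the `timed = sorted(...)` / `k = timed[0][1]` idiom. The same idiom reduced to plain Python, with invented labels and timings:

# candidates pair a label with a kernel-like object; timings are faked for the sketch
candidates = [("beam2", "kernel_from_beam_search"), ("hc", "kernel_hand_coded")]
fake_time_s = {"kernel_from_beam_search": 42.0e-6, "kernel_hand_coded": 55.0e-6}
timed = sorted(((nm, k, fake_time_s[k]) for nm, k in candidates), key=lambda x: x[2])
print(" < ".join(f"{nm:6s} : {tm*1e6:8.2f} us" for nm, _, tm in timed))
best = timed[0][1]  # the fastest candidate wins, mirroring k = timed[0][1]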