move hand_coded_optimizations to heuristic.py [pr] (#9844)

* move hand_coded_optimizations to heuristic.py [pr]

also folded all long lines

* make a copy and rename self -> k

* fix test
chenyu authored 2025-04-10 23:40:16 -04:00, committed by GitHub
parent e0ec8be37d
commit 8c6299bced
15 changed files with 168 additions and 149 deletions
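
In short: hand_coded_optimizations moves from a mutating Kernel method to a free function in tinygrad/codegen/heuristic.py that copies the kernel and returns the optimized copy. A minimal sketch of the call-site change, with ast and renderer standing in for whatever the caller already has in scope:

    # before: a method that mutated the kernel in place and returned self
    lin = Kernel(ast, opts=renderer)
    lin.hand_coded_optimizations()

    # after: a free function that leaves its input kernel untouched
    from tinygrad.codegen.heuristic import hand_coded_optimizations
    lin = Kernel(ast, opts=renderer)
    lin = hand_coded_optimizations(lin)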

View File

@@ -3,6 +3,7 @@ from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.engine.search import beam_search, bufs_from_lin
@@ -83,7 +84,7 @@ if __name__ == "__main__":
# always try hand coded opt
lin = Kernel(si.ast, opts=device.renderer)
-lin.hand_coded_optimizations()
+lin = hand_coded_optimizations(lin)
lins.append((lin, "HC"))
# maybe try tensor cores

View File

@@ -40,7 +40,6 @@ sched = C.schedule()
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import CompilerOptions
lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
-#lin.hand_coded_optimizations()
lin.linearize()
from tinygrad.runtime.ops_cpu import renderer
src = renderer("mmult", lin.uops)

View File

@@ -2,6 +2,7 @@ import random
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import actions
from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.helpers import tqdm
tactions = set()
@@ -23,7 +24,7 @@ if __name__ == "__main__":
for ast_str in tqdm(ast_strs):
lin = ast_str_to_lin(ast_str)
#if not lin.apply_tensor_cores():
-lin.hand_coded_optimizations()
+lin = hand_coded_optimizations(lin)
test_rebuild(lin)
# confirm linearize can be called twice
uops1 = lin.linearize().uops

View File

@@ -7,6 +7,7 @@ from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet
@@ -34,7 +35,7 @@ if __name__ == "__main__":
rawbufs = bufs_from_lin(lin)
linhc = deepcopy(lin)
-linhc.hand_coded_optimizations()
+linhc = hand_coded_optimizations(linhc)
tmhc = time_linearizer(linhc, rawbufs)
print(f"{tmhc*1e6:10.2f} HC ", linhc.colored_shape())

View File

@@ -7,6 +7,7 @@ from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.codegen.heuristic import hand_coded_optimizations
import numpy as np
def move_jit_captured_to_dev(captured, device="DSP"):
@@ -56,7 +57,7 @@ if __name__ == "__main__":
ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
GlobalCounters.kernel_count -= 1
if not getenv("NOOPT"): k.hand_coded_optimizations()
if not getenv("NOOPT"): k = hand_coded_optimizations(k)
p2 = k.to_program()
new_ei = replace(ei, prg=CompiledRunner(p2))
new_ei.run()

View File

@@ -1,11 +1,12 @@
import random
from tinygrad.helpers import getenv
from tinygrad.engine.search import beam_search, bufs_from_lin
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
def optimize_kernel(k):
  # TODO: update this
-  return k.hand_coded_optimizations()
+  return hand_coded_optimizations(k)
if __name__ == '__main__':
hcopt_wins = beam_wins = tie = 0

View File

@@ -4,6 +4,7 @@ from tinygrad import Tensor, Device, nn
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
from tinygrad.ops import Ops
from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite
@@ -37,7 +38,7 @@ if __name__ == "__main__":
if BEAM:
with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
elif NOOPT: pass
-else: k.hand_coded_optimizations()
+else: k = hand_coded_optimizations(k)
kernels.append(k)
with Timing("***** model lower in "): uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]

View File

@@ -1,6 +1,7 @@
from tinygrad import Device
from tinygrad.helpers import getenv, DEBUG, BEAM
from tinygrad.engine.search import beam_search, bufs_from_lin
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
if __name__ == "__main__":
@@ -20,14 +21,14 @@ if __name__ == "__main__":
k = new_lin()
# k.required_optimizations()
-if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k = hand_coded_optimizations(k)
assert BEAM > 0
lins = [(("tc" if used_tensor_cores else "hc"), k)]
if used_tensor_cores:
lins.append(("hc", new_lin()))
-lins[-1][1].hand_coded_optimizations()
+lins[-1][1] = hand_coded_optimizations(lins[-1][1])
kb = new_lin()
# kb.required_optimizations()
test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization

View File

@@ -2,6 +2,7 @@ from tinygrad import Device, dtypes
from tinygrad.helpers import getenv, colorize_float
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from test.external.fuzz_linearizer import get_fuzz_rawbufs
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.engine.search import bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.tensor import _to_np_dtype
@@ -21,7 +22,7 @@ if __name__ == "__main__":
for num,ast in enumerate(ast_strs):
# cuda compile
culin = ast_str_to_lin(ast, opts=cudev.renderer)
-culin.hand_coded_optimizations()
+culin = hand_coded_optimizations(culin)
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.membufs)
cuda_prg = CompiledRunner(culin.to_program())
@@ -31,7 +32,7 @@ if __name__ == "__main__":
rdr = nvdev.renderer
rdr.device = "NV"
nvlin = ast_str_to_lin(ast, opts=rdr)
-nvlin.hand_coded_optimizations()
+nvlin = hand_coded_optimizations(nvlin)
nv_prg = CompiledRunner(nvlin.to_program())
nvbufs = bufs_from_lin(nvlin)
test_nvbufs = get_fuzz_rawbufs(nvlin) if not has_bf16 else nvbufs

View File

@@ -1,6 +1,7 @@
import itertools
from tinygrad import Device
from tinygrad.engine.realize import CompiledRunner
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.helpers import getenv, colorize_float
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import bufs_from_lin
@@ -23,7 +24,7 @@ if __name__ == "__main__":
# cuda compile
dev.compiler = CUDACompiler(dev.arch)
lin = ast_str_to_lin(ast, opts=dev.renderer)
-lin.hand_coded_optimizations()
+lin = hand_coded_optimizations(lin)
cuda_prg = CompiledRunner(lin.to_program())
bufs = bufs_from_lin(lin)
@@ -31,7 +32,7 @@ if __name__ == "__main__":
# ptx compile
dev.compiler = PTXCompiler(dev.arch)
lin = ast_str_to_lin(ast, opts=ptx)
-lin.hand_coded_optimizations()
+lin = hand_coded_optimizations(lin)
lin.linearize()
ptx_prg = CompiledRunner(lin.to_program())

View File

@@ -12,6 +12,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
from tinygrad.dtype import DType, dtypes
@@ -1000,7 +1001,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Kernel(r.schedule()[-1].ast)
-k.hand_coded_optimizations()
+k = hand_coded_optimizations(k)
k.linearize()
stores = [u for u in k.uops if u.op is Ops.STORE]
@@ -1305,7 +1306,7 @@ class TestLinearizer(unittest.TestCase):
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
lin = Kernel(sched_copy[-1].ast)
-lin.hand_coded_optimizations()
+lin = hand_coded_optimizations(lin)
lin.linearize()
assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
@@ -1463,7 +1464,7 @@ class TestFloat4(unittest.TestCase):
s = c.schedule()[0]
k = Kernel(s.ast)
-k.hand_coded_optimizations()
+k = hand_coded_optimizations(k)
k.linearize()
assert TestFloat4.count_float4(k) == (2, 1)
@@ -1518,7 +1519,7 @@ class TestFloat4(unittest.TestCase):
s = c.schedule()[0]
k = Kernel(s.ast)
-k.hand_coded_optimizations() # implicit trigger float4 dim
+k = hand_coded_optimizations(k) # implicit trigger float4 dim
k.linearize()
assert TestFloat4.count_float4(k) == (0, 1)
@@ -1731,7 +1732,7 @@ class TestHandCodedOpts(unittest.TestCase):
s = layer_2.schedule()[-1]
k = Kernel(s.ast)
-k.hand_coded_optimizations()
+k = hand_coded_optimizations(k)
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
# masked upcast should upcast masked axis of size 7
# masked upcast should not upcast large (20) last axis
@@ -1744,7 +1745,7 @@ class TestHandCodedOpts(unittest.TestCase):
s = monster.schedule()[-1]
k = Kernel(s.ast)
-k.hand_coded_optimizations()
+k = hand_coded_optimizations(k)
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
# should upcast the two Tensor.stacks
assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
@@ -1760,7 +1761,7 @@ class TestHandCodedOpts(unittest.TestCase):
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
k = Kernel(si.ast)
-k.hand_coded_optimizations()
+k = hand_coded_optimizations(k)
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 22: continue # not a tile transform kernel (there's a permute kernel at the end)
upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
@@ -1772,7 +1773,7 @@ class TestHandCodedOpts(unittest.TestCase):
backward_schedule = Tensor.schedule(x.grad, w.grad)
for si in backward_schedule:
k = Kernel(si.ast)
-k.hand_coded_optimizations()
+k = hand_coded_optimizations(k)
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
# heuristic number to make sure that at least some upcasts but not too many upcasts are being done
@@ -1859,8 +1860,8 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
# Check correctness of handcoded optimiztions.
k = Kernel(realized_ast)
+k = hand_coded_optimizations(k)
lins.append(k)
-k.hand_coded_optimizations()
prg = get_prg(k)
reset_bufs(outbufs)
prg.exec(real_bufs)

View File

@@ -3,6 +3,7 @@ from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.ops import Ops
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
class TestWinograd(unittest.TestCase):
def setUp(self):
@@ -26,7 +27,7 @@ class TestWinograd(unittest.TestCase):
ops = s.ast.toposort
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Kernel(s.ast)
-l.hand_coded_optimizations()
+l = hand_coded_optimizations(l)
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")

View File

@@ -0,0 +1,132 @@
import itertools
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.helpers import getenv, DEBUG, all_int, prod
from tinygrad.dtype import ImageDType
from tinygrad.ops import Ops, resolve

def hand_coded_optimizations(k:Kernel) -> Kernel:
  # make a copy so it does not mutate the input
  k = k.copy().required_optimizations()

  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
  if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
      k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
      (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
    st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
    strides0, strides1 = st0.real_strides(), st1.real_strides()
    def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
    if strides0[k.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
      for global_idx in range(k.global_dims):
        if k.full_shape[k.first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
          if DEBUG >= 3:
            print(f"MATVEC: {k.full_shape=} {k.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
          if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
          if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
          if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
          return k

  if k.opts.has_local and k.opts.has_shared and all_int(k.sts[0].shape[:k.first_reduce]):
    # are we grouping? (requires local shape support)
    if not [x for x in k.sts[0].unit_stride_axes() if x >= k.first_upcast and k.sts[0].shape[x]%4 == 0] and \
        k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
      # TODO: use 1024 if it's allowed in a smarter way
      for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
        if all(st.shape[k.first_reduce] % sz == 0 or st.shape[k.first_reduce] == 1 for st in k.sts):
          try: # may fail due to excessive smem usage
            k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
            break
          except KernelOptError: pass

  # upcast float4 images
  for buf_index,buf in enumerate(k.bufs):
    unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
    if buf.src[0].dtype.__class__ is ImageDType:
      #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {k.bufs[buf_index]}"
      if len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
        if unit_stride_axes_mul_4[0] < k.first_reduce:
          k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
        else:
          k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))

  # no more opt if we are grouping
  if k.group_for_reduces: return k

  # **** below this line need to be optional and benchmarked ****

  # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
  # to trigger the above bug, remove prod(k.full_shape[k.first_upcast:]) from the below
  # expression and run test/test_ops.py with IMAGE=2

  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
  # this can be made much smarter
  to_upcast: list[int] = []
  # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
  for axis in range(k.first_reduce):
    # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
    # for now skip upcasting here if there is a symbolic axis
    if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
        prod(k.full_shape[k.first_upcast:]) * prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
      if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
      to_upcast.append(axis)
  for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))

  # potentially do more upcasts of non reduce axes based on a heuristic
  is_dsp = k.opts is not None and k.opts.device == "DSP"
  upcasted_axis: set[int] = set()
  while resolve(prod(k.sts[0].shape[:k.first_reduce]) >= 1024):
    xb_choices = []
    # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
    for axis, upcast_amount in itertools.product(range(k.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
      # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
      if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and \
          any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index, st in enumerate(k.sts)):
        xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
                           sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
    if xb_choices:
      xb_choices = sorted(xb_choices)
      if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
      k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
      upcasted_axis.add(xb_choices[0][2])
    else: break

  # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
  if k.first_reduce < k.first_upcast and (prod(k.full_shape[k.first_upcast:]) <= 4 or \
      not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))) and (k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64):
    if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
      k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
      # if it's small, upcast a second reduce dimension too
      if k.first_reduce < k.first_upcast and s <= 3 and isinstance(s2:=k.full_unupcasted_shape[-1], int) and s2 <= 3:
        k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
    else:
      for splits in [4]:
        if k.full_unupcasted_shape[-1]%splits == 0:
          k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, splits))
          break

  # if nothing at all is upcasted and it's easy to, do an upcast
  for splits in [4]:
    if k.upcasted == 0 and k.full_unupcasted_shape and k.full_unupcasted_shape[-1] % splits == 0:
      k.apply_opt(Opt(OptOps.UPCAST, len(k.full_unupcasted_shape)-1, splits))

  # **** local groups ****

  if k.opts.has_local:
    if getenv("NOLOCALS") and k.local_dims == 0 and not k.group_for_reduces:
      k.apply_opt(Opt(OptOps.NOLOCALS))
    else:
      # prioritize making expand axes local
      local_axis_ranking = [(any(k.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(k.sts))), axis) \
          for axis in range(len(k.full_shape[:k.first_reduce]))]
      to_local: list[tuple[int, int]] = []
      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
        local_size = prod(sz for _, sz in to_local)
        local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
        if local_sz is not None: to_local.append((axis, local_sz))
      deleted_shape = 0
      for axis, local_sz in sorted(to_local[:3]):
        axis = axis - deleted_shape
        will_delete_shape = local_sz == k.full_shape[axis]
        k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
        if will_delete_shape: deleted_shape += 1
  return k
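
A usage note, as a minimal sketch mirroring the call sites above (the matmul and the default-device renderer are illustrative, not from this diff):

    from tinygrad import Tensor, Device
    from tinygrad.codegen.kernel import Kernel
    from tinygrad.codegen.heuristic import hand_coded_optimizations

    # build a kernel from the last schedule item of a small matmul
    si = (Tensor.rand(128, 128) @ Tensor.rand(128, 128)).schedule()[-1]
    k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
    k = hand_coded_optimizations(k)  # returns an optimized copy; the original kernel is untouched
    k.linearize()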

View File

@@ -4,8 +4,7 @@ from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, cast, Final, Callable, Sequence
-from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops
-from tinygrad.ops import PatternMatcher
+from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops, PatternMatcher
from tinygrad.spec import type_verify, shape_spec
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
@@ -438,129 +437,6 @@ class Kernel:
      if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
    return self

  def hand_coded_optimizations(self) -> Kernel:
    self.required_optimizations()

    # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
    MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
    if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
        self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
        (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
      st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
      strides0, strides1 = st0.real_strides(), st1.real_strides()
      def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
      if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
        for global_idx in range(self.global_dims):
          if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
            if DEBUG >= 3:
              print(f"MATVEC: {self.full_shape=} {self.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
            if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
            if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
            if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
            return self

    if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
      # are we grouping? (requires local shape support)
      if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \
          self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
        # TODO: use 1024 if it's allowed in a smarter way
        for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
            try: # may fail due to excessive smem usage
              self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
              break
            except KernelOptError: pass

    # upcast float4 images
    for buf_index,buf in enumerate(self.bufs):
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
      if buf.src[0].dtype.__class__ is ImageDType:
        #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
          if unit_stride_axes_mul_4[0] < self.first_reduce:
            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
          else:
            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))

    # no more opt if we are grouping
    if self.group_for_reduces: return self

    # **** below this line need to be optional and benchmarked ****

    # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
    # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
    # expression and run test/test_ops.py with IMAGE=2

    # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
    # this can be made much smarter
    to_upcast: list[int] = []
    # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
    for axis in range(self.first_reduce):
      # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
      # for now skip upcasting here if there is a symbolic axis
      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
          prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
        if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
        to_upcast.append(axis)
    for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

    # potentially do more upcasts of non reduce axes based on a heuristic
    is_dsp = self.opts is not None and self.opts.device == "DSP"
    upcasted_axis: set[int] = set()
    while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
      xb_choices = []
      # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
      for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
        # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
      if xb_choices:
        xb_choices = sorted(xb_choices)
        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
        self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
        upcasted_axis.add(xb_choices[0][2])
      else: break

    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
    if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
      if isinstance(s:=self.full_unupcasted_shape[-1], int) and s <= 32: # NOTE: cannot loop unroll symbolic axis
        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
        # if it's small, upcast a second reduce dimension too
        if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
      else:
        for splits in [4]:
          if self.full_unupcasted_shape[-1]%splits == 0:
            self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
            break

    # if nothing at all is upcasted and it's easy to, do an upcast
    # TODO: this is breaking the tests
    for splits in [4]:
      if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
        self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))

    # **** local groups ****

    if self.opts.has_local:
      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
        self.apply_opt(Opt(OptOps.NOLOCALS))
      else:
        # prioritize making expand axes local
        local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
        to_local: list[tuple[int, int]] = []
        for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
          local_size = prod(sz for _, sz in to_local)
          local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
          if local_sz is not None: to_local.append((axis, local_sz))
        deleted_shape = 0
        for axis, local_sz in sorted(to_local[:3]):
          axis = axis - deleted_shape
          will_delete_shape = local_sz == self.full_shape[axis]
          self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
          if will_delete_shape: deleted_shape += 1
    return self

  # **** kernel outputs ****
  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)

View File

@@ -7,6 +7,7 @@ from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
from tinygrad.engine.schedule import ScheduleItem
# **************** Program Creation ****************
@@ -15,7 +16,7 @@ logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS
def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
  k = Kernel(ast, opts=renderer).required_optimizations()
  if not NOOPT:
-    if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
+    if not k.apply_tensor_cores(getenv("TC", 1)): k = hand_coded_optimizations(k)
    if BEAM >= 1:
      from tinygrad.engine.search import beam_search, bufs_from_lin
      kb = Kernel(ast, opts=renderer).required_optimizations()
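
Taken together, get_kernel's selection order is unchanged by this refactor: NOOPT skips optimization, tensor cores are tried first, hand_coded_optimizations is the fallback, and BEAM search (whose first lines are shown above) can still replace the result. A hedged sketch of pinning the fallback path from user code, assuming NOOPT and BEAM are the ContextVars imported in these files:

    from tinygrad import Tensor
    from tinygrad.helpers import Context

    out = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).relu()
    with Context(NOOPT=0, BEAM=0):  # defaults: get_kernel lands on hand_coded_optimizations(k)
      out.realize()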