From 38dcadf07b4d1ac89015fff7ee24c59c429eeca1 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:52:07 -0700 Subject: [PATCH] delete kernel.py (#12040) * delete kernel.py * delete that file * rip and tear * don't test search * imports * fix torch frontend * not a part of regen --- .github/workflows/test.yml | 6 +- extra/gemm/simple_matmul.py | 2 +- extra/optimization/helpers.py | 2 +- extra/to_movement_ops.py | 2 +- test/external/external_test_train_gpt2.py | 2 +- test/external/fuzz_linearizer.py | 2 +- .../external/process_replay/process_replay.py | 2 +- tinygrad/codegen/opt/heuristic.py | 102 ++-- tinygrad/codegen/opt/kernel.py | 435 ------------------ tinygrad/codegen/opt/search.py | 16 +- tinygrad/engine/realize.py | 2 +- tinygrad/viz/serve.py | 2 +- 12 files changed, 48 insertions(+), 527 deletions(-) delete mode 100644 tinygrad/codegen/opt/kernel.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8c12852c3e..21fba630f7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -372,7 +372,7 @@ jobs: CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus python extra/optimization/extract_dataset.py gzip -c /tmp/sops > extra/datasets/sops.gz - DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py + #DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py - name: Repo line count < 18000 lines run: MAX_LINE_COUNT=18000 python sz.py @@ -532,8 +532,8 @@ jobs: opencl: 'true' - name: Test ONNX (GPU) run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20 - - name: Test Optimization Helpers - run: DEBUG=1 python3 extra/optimization/test_helpers.py + #- name: Test Optimization Helpers + # run: DEBUG=1 python3 extra/optimization/test_helpers.py #- name: Test Action Space # run: DEBUG=1 GPU=1 python3 extra/optimization/get_action_space.py - name: Test Beam Search diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 7b9c072785..0c91005a16 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -2,7 +2,7 @@ import numpy as np from tinygrad import dtypes, Tensor from tinygrad.helpers import getenv, get_single_element from tinygrad.dtype import _to_np_dtype -from tinygrad.codegen.opt.kernel import OptOps +from tinygrad.codegen.opt import OptOps from tinygrad.engine.realize import lower_schedule dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 94eacaa8d1..deff032ca6 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -1,6 +1,6 @@ # stuff needed to unpack a kernel from tinygrad import Variable -from tinygrad.codegen.opt.kernel import Opt, OptOps +from tinygrad.codegen.opt import Opt, OptOps from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.dtype import dtypes, PtrDType from tinygrad.shape.shapetracker import ShapeTracker diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py index 68d4ef3dfe..dcbf90d104 100644 --- a/extra/to_movement_ops.py +++ b/extra/to_movement_ops.py @@ -2,7 +2,6 @@ import itertools from enum import Enum, auto from collections import defaultdict from typing import List, Tuple, DefaultDict -from extra.optimization.helpers import load_worlds, ast_str_to_ast from tinygrad.helpers import prod, tqdm from tinygrad.uop.ops import UOp, Ops from tinygrad.shape.shapetracker import 
ShapeTracker @@ -147,6 +146,7 @@ def test_rebuild_bufferop_st(ast:UOp): for src in ast.src: test_rebuild_bufferop_st(src) if __name__ == "__main__": + from extra.optimization.helpers import load_worlds, ast_str_to_ast ast_strs = load_worlds(False, False, True)[:2000] for ast_str in tqdm(ast_strs): test_rebuild_bufferop_st(ast_str_to_ast(ast_str)) diff --git a/test/external/external_test_train_gpt2.py b/test/external/external_test_train_gpt2.py index 4994ad4caa..196beff5a1 100644 --- a/test/external/external_test_train_gpt2.py +++ b/test/external/external_test_train_gpt2.py @@ -2,7 +2,7 @@ import unittest from tinygrad.uop.ops import UOp, Ops -from tinygrad.codegen.opt.search import Opt, OptOps +from tinygrad.codegen.opt import Opt, OptOps from tinygrad.dtype import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 53a45a51ac..8382ce2747 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -22,7 +22,7 @@ if os.getenv("VALIDATE_HCQ", 0) != 0: from tinygrad import Tensor, Device, dtypes from tinygrad.tensor import _to_np_dtype from tinygrad.codegen.opt.kernel import Kernel -from tinygrad.codegen.opt.kernel import Opt, OptOps +from tinygrad.codegen.opt import Opt, OptOps from tinygrad.codegen.opt.search import get_kernel_actions, bufs_from_lin from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 0a82ac31c5..5baf9f2015 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -12,7 +12,7 @@ try: from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.engine.realize import get_program from tinygrad.uop.ops import UOp, Ops, KernelInfo - from tinygrad.codegen.opt.kernel import Opt + from tinygrad.codegen.opt import Opt from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.device import Device except ImportError as e: diff --git a/tinygrad/codegen/opt/heuristic.py b/tinygrad/codegen/opt/heuristic.py index 5aed5ca3f6..d95a73f864 100644 --- a/tinygrad/codegen/opt/heuristic.py +++ b/tinygrad/codegen/opt/heuristic.py @@ -3,12 +3,9 @@ from tinygrad.codegen.opt import Opt, OptOps, KernelOptError from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX from tinygrad.dtype import ImageDType from tinygrad.uop.ops import Ops, resolve, AxisType - -# both versions -from tinygrad.codegen.opt.kernel import Kernel from tinygrad.codegen.opt.postrange import Scheduler -def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: +def hand_coded_optimizations(k:Scheduler) -> list[Opt]: # first try the tensor cores """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, the corresponding opts are included in the returned list; otherwise the remaining heuristics are applied. Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
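As an illustration of the contract above (not part of this patch), a minimal NumPy sketch; the half-in/float-out dtypes and the 16x8x16 tile are assumptions for the example — the real values come from the selected TensorCore's dims, dtype_in, and dtype_out:

import numpy as np
M, N, K = 16, 8, 16                                  # assumed tile dims; real ones come from tc.dims
A = np.random.randn(M, K).astype(np.float16)         # input operand, tc.dtype_in
B = np.random.randn(K, N).astype(np.float16)         # input operand, tc.dtype_in
C = np.zeros((M, N), dtype=np.float32)               # accumulator, tc.dtype_out
D = A.astype(np.float32) @ B.astype(np.float32) + C  # D(M, N) = A(M, K) * B(K, N) + C(M, N)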
@@ -38,15 +35,7 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: pass if good_tc_opt: # skip hand-coded TC opts if AMX, upcasting will make kernel slower - if isinstance(k, Kernel) and (tc_opts:=tk.tensor_core_opts) is not None and not AMX: - # hand-coded TC opts - for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N - szs = [sz for sz in [5,4,3,2] if tk.full_shape[tc_opts.axes[tc_dim]] % sz == 0] - if szs: tk.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0])) - - if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if tk.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N - tk.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0])) - elif isinstance(k, Scheduler) and rngs is not None and not AMX: + if rngs is not None and not AMX: for tc_dim in [1,0]: # attempt to upcast M and N szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None] if szs: @@ -64,32 +53,17 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \ (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: - if isinstance(k, Kernel): - st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])] - strides0, strides1 = st0.real_strides(), st1.real_strides() - def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides)) - if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \ - not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)): - for global_idx in k.axes_of(AxisType.GLOBAL): - if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: - if DEBUG >= 3: - print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}") - if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) - if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) - if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) - return k.applied_opts - else: - idx0, idx1 = mulop.src[0].src[0].src[1], mulop.src[1].src[0].src[1] - first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0] - if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges): - for global_idx in k.axes_of(AxisType.GLOBAL): - if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: - if DEBUG >= 3: - print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}") - if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) - if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) - if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) - return k.applied_opts + idx0, idx1 = mulop.src[0].src[0].src[1], mulop.src[1].src[0].src[1] + first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0] + if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges): + for global_idx in 
k.axes_of(AxisType.GLOBAL): + if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: + if DEBUG >= 3: + print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}") + if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) + if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) + if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) + return k.applied_opts # are we grouping? (requires local shape support) if resolve(prod(k.output_shape[i] for i in k.upcastable_dims) <= 2048, False): @@ -102,11 +76,8 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: # upcast float4 images for buf_index,buf in enumerate(k.bufs): if isinstance(buf.src[0].dtype, ImageDType): - if hasattr(k, "sts"): - unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0] - else: - # part of real_strides - unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0] + # part of real_strides + unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0] if len(unit_stride_axes_mul_4): if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims: k.apply_opt(Opt(OptOps.UPCAST, axis, 4)) @@ -122,11 +93,9 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: to_upcast: list[int] = [] # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) for axis in k.upcastable_dims: - if isinstance(k, Kernel): is_masked = any(st.axis_is_masked(axis) for st in k.sts) - else: - # for Schedule, we check if the range is used in INDEX gates or WHERE gates - is_masked = any(len(st.src) > 2 and k.rngs[axis] in st.src[2].parents for st in k.bufs) or \ - any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE) + # for Schedule, we check if the range is used in INDEX gates or WHERE gates + is_masked = any(len(st.src) > 2 and k.rngs[axis] in st.src[2].parents for st in k.bufs) or \ + any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE) if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7: if DEBUG >= 4: print(f"upcasting masked axis : {axis}") to_upcast.append(axis) @@ -141,24 +110,16 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]): # if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue - if isinstance(k, Kernel): - # must have stride 0 on a view - # must have all non stride 0 on what's upcasted before - if any(st.views[-1].strides[axis] == 0 and \ - all(x != 0 for t,x in zip(k.axis_types, st.real_strides()) if t in (AxisType.UPCAST, AxisType.UNROLL)) for st in k.sts): - xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), - sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount)) - else: - rng = k.rngs[axis] - if any(rng not in b.src[1].parents and all(r2 in 
b.src[1].parents for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs): - num_strides, sum_strides = 0, 0 - for b in k.bufs: - if rng in b.src[1].parents: num_strides += 1 - for c in b.src[1].split_uop(Ops.ADD): - if c is rng: sum_strides += 1 - if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg - if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg - xb_choices.append((num_strides, sum_strides, axis, upcast_amount)) + rng = k.rngs[axis] + if any(rng not in b.src[1].parents and all(r2 in b.src[1].parents for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs): + num_strides, sum_strides = 0, 0 + for b in k.bufs: + if rng in b.src[1].parents: num_strides += 1 + for c in b.src[1].split_uop(Ops.ADD): + if c is rng: sum_strides += 1 + if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg + if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg + xb_choices.append((num_strides, sum_strides, axis, upcast_amount)) if xb_choices: xb_choices = sorted(xb_choices) if DEBUG >= 4: print(f"more upcast axis : {xb_choices}") @@ -196,11 +157,8 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]: k.apply_opt(Opt(OptOps.NOLOCALS)) else: # prioritize making expand axes local - if isinstance(k, Kernel): - local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)] - else: - local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].parents for b in k.bufs), axis) \ - for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST] + local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].parents for b in k.bufs), axis) \ + for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST] to_local: list[tuple[int, int]] = [] for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): local_size = prod(sz for _, sz in to_local) diff --git a/tinygrad/codegen/opt/kernel.py b/tinygrad/codegen/opt/kernel.py deleted file mode 100644 index 832a8e5fa6..0000000000 --- a/tinygrad/codegen/opt/kernel.py +++ /dev/null @@ -1,435 +0,0 @@ -from __future__ import annotations -import itertools, functools, math -from dataclasses import dataclass -from collections import defaultdict -from typing import cast, Final, Callable, Sequence -from tinygrad.codegen.opt import OptOps, Opt, KernelOptError, check, axis_letters, axis_colors -from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType -from tinygrad.uop.spec import type_verify, ast_spec -from tinygrad.device import Device -from tinygrad.codegen.opt.tc import TensorCore -from tinygrad.renderer import Renderer -from tinygrad.dtype import ImageDType -from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.view import strides_for_shape, get_contraction -from tinygrad.codegen.opt.swizzler import view_left, view_left_through_load - -@dataclass -class TensorCoreOptions: - axes: tuple[int, ...] # the location of the original N and M axes if still in the shape - axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape - axis_pads: tuple[tuple[int, int], ...] 
- def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed - axes, axes_exist = list(self.axes), list(self.axes_exist) - for tc_dim in [i for i in range(2) if axes_exist[i]]: - if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1 - elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False - self.axes, self.axes_exist = tuple(axes), tuple(axes_exist) - -class Kernel: - def __init__(self, ast:UOp, opts:Renderer|None=None): - assert ast.op is Ops.SINK, ast.op - self.ast = ast - - self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer - # verify AST matches the spec - if __debug__: type_verify(list(self.ast.toposort()), ast_spec) - - self.vars: list[Variable] = self.ast.variables() - # NOTE: this requires a specific order with the [::-1], this is likely a bug - self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.st is not None][::-1] - - # create new shapetrackers inside this kernel, we will permute them - self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs] - - # add the shapetrackers for each reduce - # we use this to track which axes are reduced in each reduce - self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS] - for x in self.reduceops: - self.sts.append(unwrap(x.st)) - self.sts.append(unwrap(x.src[0].st)) - - # add a shapetracker to the end to track the full shape, with 0 strides so it can merge - full_shape = ast.full_shape - self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape))) - - # parameters for optimization - self.tensor_core: TensorCore|None = None - self.tensor_core_opts: TensorCoreOptions|None = None - self.use_tensor_cores: int = 0 - self.applied_opts: list[Opt] = [] - self.dont_use_locals = False - self.finalized: bool = False - - # group simplifies - self.simplify_ones() - self.simplify_merge_adjacent() - - # axis types - global_loops = AxisType.GLOBAL if self.opts.has_local else AxisType.LOOP - self.axis_types: list[AxisType] = [AxisType.REDUCE if resolve(x!=y) else global_loops for x,y in zip(self.output_shape, self.full_shape)] - - # confirm all reduce axes are at the end - if (final_reduces := [x for x in self.axis_types if x == AxisType.REDUCE]) and final_reduces != self.axis_types[-len(final_reduces):]: - raise RuntimeError(f"reduces are not at the end of the shape {self.full_shape} -> {self.output_shape}") - - def copy(self): - ret = type(self).__new__(type(self)) - - # base linearizer params - ret.opts, ret.ast = self.opts, self.ast - - # things downstream of the AST - ret.reduceops, ret.vars, ret.bufs = self.reduceops, self.vars, self.bufs - ret.sts = self.sts[:] - ret.axis_types = self.axis_types[:] - - # parameters for optimizations - ret.applied_opts, ret.dont_use_locals = self.applied_opts[:], self.dont_use_locals - ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores - ret.finalized = self.finalized - - return ret - - @property - def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None - @property - def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape - - @property - def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape - @property - def shape_len(self) -> int: return len(self.full_shape) - - def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)] - @property - def upcasted(self) -> int: return 
len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL)) - @property - def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE)) - - # heuristic helpers - @property - def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \ - if isinstance(s:=self.full_shape[i], int) and s > 1] - @property - def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \ - if isinstance(s:=self.full_shape[i], int) and s > 1] - - # ******************** colors and names ******************** - - def colors(self) -> list[str]: - assert len(self.axis_types) == self.shape_len, "colors size mismatch" - return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types] - - def colored_shape(self, pad:int|None=None, dense=False) -> str: - shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape] - ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors())) - if pad: ret += ' '*(pad-ansilen(ret)) - return ret - - kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int) - @functools.cached_property - def name(self) -> str: - # kernel name (before late upcast) - kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort()) else "E") - suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())]) - name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix - - # name the function something unique - Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1 - num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else "" - return name + colored(num, 'BLACK') - - # ******************** base simplifiers ******************** - - # apply reshape and permute to all shapetrackers - def reshape(self, new_shape_fxn:Callable[[tuple[sint, ...]], Sequence[sint]]): - self.sts = [st.reshape(tuple(new_shape_fxn(st.shape))) for st in self.sts] - def permute(self, new_axes:Sequence[int]): self.sts = [st.permute(tuple(new_axes)) for st in self.sts] - - # axis : the axis to pull from - # amount : the amount to take - # top : if you want to pull that amount from the top - # insert_at : place to insert the new stuff - def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None) -> int: - if insert_at is None: insert_at = self.shape_len - self.axis_types.insert(insert_at, new_type) - move_axis = axis if top else axis+1 - if move_axis < insert_at: insert_at += 1 - def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:] - new_axes = [i for i in range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis] - self.reshape(new_shape_fxn) - self.permute(new_axes) - return insert_at - - # ******************** complex simplifiers ******************** - - def simplify_ones(self) -> bool: - # remove places where the shape is all ones - if any(all_ones:=[s==1 for s in self.full_shape]): - if hasattr(self, 'axis_types'): - self.axis_types = [x for i,x in enumerate(self.axis_types) if not all_ones[i]] - self.reshape(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]]) - return True - 
return False - - def simplify_merge_adjacent(self): - assert not hasattr(self, 'axis_types'), "don't call this after init" - if self.shape_len == 0: return - shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] - # NOTE: we can't use self.first_reduce yet - first_reduce = [resolve(x!=y) for x,y in zip(self.output_shape+(0,), self.full_shape+(1,))].index(True) - - # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - # TODO: remove membufs - membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}]) - if isinstance(membufs[0].base.dtype, ImageDType): - base_shape = membufs[0].base.dtype.shape - if shape_idx_groups := get_contraction(self.output_shape, base_shape): - special_strides: tuple[sint, ...] = tuple() - for i,g in enumerate(shape_idx_groups): - shape_piece = tuple(self.output_shape[x] for x in g) - assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}" - special_strides += strides_for_shape(shape_piece) - # adding the fake image shape - shapes.append(self.output_shape) - strides.append(special_strides) - - # merge dimensions if we can, multi _merge_dims - # NOTE: this does not always preserve the reduce dimension - # TODO: move this into shapetracker, with tests! - # TODO: how does this work with multi-reduce? - rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)] - for i in range(1, len(shapes[0])): - can_merge = [] - for s,st,ret in zip(shapes, strides, rets): - # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case - si, sti, last_st = s[i], st[i], ret[-1][1] - can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0))) - # more can merge than this - mergeable = all(can_merge) and i != first_reduce - for j,(s,st) in enumerate(zip(shapes, strides)): - if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i]) - else: rets[j].append((s[i], st[i])) - - # do the reshapes - for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x])) - - # ******************** apply optimizations ******************** - - def real_axis(self, op:OptOps, axis:int|None): - try: - if axis is None: return -1 - if op is OptOps.UNROLL: return self.unrollable_dims[axis] - if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis] - check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}") - return axis - except IndexError as e: raise KernelOptError from e - - def apply_opt(self, opt:Opt, append_opt:bool=True) -> int|None: - if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized") - if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals") - - if opt.op is OptOps.TC: - check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine - check(len(self.opts.tensor_cores) > 0, "must have tensor cores") - check(opt.axis is not None, "tensor core opts must have an axis") - check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg") - check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select") - check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt") - check(0 < (use_tensor_cores:=cast(tuple, 
opt.arg)[2]) <= 2, "use_tensor_cores value is not valid") - check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available") - self.applied_opts.append(opt) - return None - - axis = self.real_axis(opt.op, opt.axis) - - if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs - elif opt.arg is not None: - check(isinstance(opt.arg, int), "arg should be int") - amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis] - check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless") - if opt.op is not OptOps.PADTO: - # we check both the full_shape and each shape - check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}") - for st in self.sts: check(st.shape[axis] == 1 or st.shape[axis] % amt == 0, f"no longer valid shift {st.shape[axis]=}, {amt=}") - else: amt = -1 - - if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \ - (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): - acc_sz = self.reduceop.dtype.itemsize - upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)]) - local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)]) - smem_sz = amt*acc_sz*upcast_sz*local_sz - check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}") - - new_axis = None - if opt.op is OptOps.LOCAL: # cyan - # NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache) - # it's disabled for now since it makes BEAM slow for little gain - check(self.opts.has_local, "target does not support local") - check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals") - new_axis = self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1) - elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green - check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem") - check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group") - check(not self.tensor_core, "can't group with tensor cores") - check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces") - new_axis = self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_at=min(self.axes_of(AxisType.REDUCE))) - elif opt.op is OptOps.UNROLL: # purple - check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted") - check(amt <= 32, "don't unroll more than 32") - new_axis = self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None) - elif opt.op is OptOps.UPCAST: # yellow - check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}") - # NOTE: assume the first get_local_axes() LOCAL are for TC - check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals") - check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16") - new_axis = self.shift_to(axis, amt, AxisType.UPCAST, - insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1) - elif opt.op is OptOps.NOLOCALS: - check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support 
local or already not using locals") - check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals") - self.dont_use_locals = True - elif opt.op is OptOps.SWAP: - check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}") - check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}") - permute = list(range(self.shape_len)) - permute[axis], permute[amt] = permute[amt], permute[axis] - self.permute(tuple(permute)) - elif opt.op is OptOps.PADTO: - check(not self.vars, "does not work with symbolic shape") - check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "cannot pad upcasted") - # ok to pad SUM if all parent ALU ops have f(0) = 0 - if (r:=self.reduceop) is not None and self.axis_types[axis] in (AxisType.GROUP_REDUCE, AxisType.REDUCE): - check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}") - padded = False - for i,st in enumerate(self.sts): - if (s:=st.shape[axis]) == 1: continue # reduced - check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}") - if (ru := round_up(cast(int, s), amt) - s): - # pad right seems to be faster - self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1)) - padded = True - check(padded, "nothing was padded") - - if append_opt: self.applied_opts.append(opt) - if self.simplify_ones() and self.tensor_core_opts: - self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones() - return new_axis - - def apply_opts(self, opts:Sequence[Opt]) -> Kernel: - for opt in opts: self.apply_opt(opt) - return self - - # **** kernel outputs, mostly tensor cores **** - - def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> TensorCoreOptions|None: - has_cast = tc.dtype_in != tc.dtype_out - if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None - - mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0] - if mul_op.op is not Ops.MUL: return None - - def buf_index(src:UOp) -> int|None: - # TODO: apply tc even if the sources are not from LOAD - if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src) - try: - if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0]) - except ValueError: return None - return None - if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None - - buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() - axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0] - axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0] - if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None - - axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE))) - if not (axis < len(axis_choices)): return None - - s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k - axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0)) - if axis_pads and (opt_level < 2): return None - if DEBUG >= 3: print("TENSOR CORES", 
axis_buf0, axis_buf1, tc) - return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads) - - def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool: - if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD: - tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]] - for tc in tensor_cores: - tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops] - if tensor_core_opts[0] is None: continue - # can only fuse reduces with the same tc options - assert all_same(tensor_core_opts) - self.tensor_core_opts = tc_opts = tensor_core_opts[0] - - # attempt to pad the tensor axes that require it - try: - for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail - except KernelOptError: continue - # tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M) - for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False) - for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0 - self.tensor_core = tc - self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA - return True - return False - - # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2'] - def shape_str(self) -> list[str]: - ret: list[str] = [] - cnt: dict[AxisType, int] = {} - for x in self.axis_types: - cnt[x] = (cnt[x] + 1) if x in cnt else 0 - ret.append(f"{axis_letters[x]}{cnt[x]}") - return ret - def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms]) - - def get_optimized_ast(self, name_override:str|None=None) -> UOp: - @functools.cache - def fixup_ast(op:UOp) -> UOp: - ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821 - if op.op in GroupOp.Buffer and op in self.bufs: - st = self.sts[self.bufs.index(op)] - # replace the VIEW source - return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:]) - if op.op is Ops.SINK: - # NOTE: should group_for_reduces be added to the local_dims? 
- # TODO: arg.name should be able to be None - kernel_name = ret.arg.name if ret.arg is not None and ret.arg.name != "test" else self.name if name_override is None else name_override - return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts))) - if op.op is Ops.REDUCE_AXIS: - reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2 - changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i])) - axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.GROUP_REDUCE, AxisType.UNROLL) if i in changed) - if (tc := self.tensor_core) and self.use_tensor_cores == 1: - # get reduce/upcast axes for the tensor cores - tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))]) - base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())]) - tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)]) - - # permute the srcs - srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src) - for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))): - src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg - srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(permaxis)) - - # construct the op - wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes) - wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=( - UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]), - UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]), - UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg) - tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2]) - - # preserve any other reduce - return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop - - ret = ret.replace(arg = (op.arg[0], axes)) - return ret - self.finalized = True - fixed_ast = fixup_ast(self.ast) - del fixup_ast - return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST") diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index 19dfa78b6e..953118cb90 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -12,8 +12,6 @@ from tinygrad.tensor import Tensor from tinygrad.engine.realize import CompiledRunner, get_program from tinygrad.renderer import ProgramSpec -# both versions -from tinygrad.codegen.opt.kernel import Kernel from tinygrad.codegen.opt.postrange import Scheduler actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)] @@ -63,7 +61,7 @@ def timeout_handler(signum, frame): if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT") raise TimeoutException() -def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]: +def _try_compile_linearized_w_idx(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]: if hasattr(signal, "alarm"): signal.signal(getattr(signal, 'SIGALRM'), timeout_handler) # set timeout @@ -98,7 +96,7 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_ # get 
(scrap) buffers for timing the linearizer # NOTE: there's also bufs_from_ast in postrange -def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: +def bufs_from_lin(lin:Scheduler, allocate:bool=True) -> list[Buffer]: bufsts: defaultdict[int, list[UOp]] = defaultdict(list) for x in lin.bufs: if x.src[0].base.op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].base.arg].append(x) @@ -115,7 +113,7 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: return cast(list[Buffer], rawbufs) # get dictionary of all possible actions -def get_kernel_actions(lin:Kernel|Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Kernel|Scheduler]: +def get_kernel_actions(lin:Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Scheduler]: acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) kernel_actions = (actions if candidates is None else candidates).copy() @@ -139,7 +137,7 @@ def get_kernel_actions(lin:Kernel|Scheduler, include_0=True, candidates:list[Opt return acted_lins beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG") -def beam_search(lin:Kernel|Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value): +def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value): global beam_pool key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix} if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None: @@ -147,7 +145,7 @@ def beam_search(lin:Kernel|Scheduler, rawbufs:list[Buffer], amt:int, allow_test_ for o in val[len(lin.applied_opts):]: ret.apply_opt(o) return ret - beam: list[tuple[Kernel|Scheduler, float]] = [(lin, float("inf"))] + beam: list[tuple[Scheduler, float]] = [(lin, float("inf"))] seen_libs = set() default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0 @@ -168,8 +166,8 @@ def beam_search(lin:Kernel|Scheduler, rawbufs:list[Buffer], amt:int, allow_test_ exiting, st = False, time.perf_counter() dev = Device[lin.opts.device] while not exiting: - acted_lins: list[Kernel|Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam]) - timed_lins: list[tuple[Kernel|Scheduler, float]] = [] + acted_lins: list[Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam]) + timed_lins: list[tuple[Scheduler, float]] = [] _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler) least_compute_ops = math.inf for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))): diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 8e622c07f0..b1642a284a 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -8,7 +8,7 @@ from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates from tinygrad.engine.schedule import ScheduleItem from tinygrad.codegen import full_rewrite -from tinygrad.codegen.opt.kernel import Opt +from tinygrad.codegen.opt import Opt # **************** Program Creation **************** diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index ab2d6d6335..f9d3658eb4 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -11,7 +11,7 @@ 
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp, from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes -from tinygrad.codegen.opt.kernel import axis_colors +from tinygrad.codegen.opt import axis_colors uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
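For code outside this diff that still imports from the deleted module, the migration implied by the substitutions above is a one-line import change; a minimal sketch (the Opt construction mirrors the UPCAST usage in heuristic.py, and assumes Opt/OptOps remain re-exported from tinygrad.codegen.opt as shown throughout this patch):

# before: from tinygrad.codegen.opt.kernel import Opt, OptOps
from tinygrad.codegen.opt import Opt, OptOps

opt = Opt(OptOps.UPCAST, 0, 4)  # upcast axis 0 by a factor of 4, matching the pattern in heuristic.py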