mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
delete kernel.py (#12040)
* delete kernel.py * delete that file * rip and tear * don't test search * imports * fix torch frontend * not a part of regen
This commit is contained in:
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -372,7 +372,7 @@ jobs:
|
||||
CAPTURE_PROCESS_REPLAY=1 python test/test_tiny.py TestTiny.test_plus
|
||||
python extra/optimization/extract_dataset.py
|
||||
gzip -c /tmp/sops > extra/datasets/sops.gz
|
||||
DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
|
||||
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
|
||||
- name: Repo line count < 18000 lines
|
||||
run: MAX_LINE_COUNT=18000 python sz.py
|
||||
|
||||
@@ -532,8 +532,8 @@ jobs:
|
||||
opencl: 'true'
|
||||
- name: Test ONNX (GPU)
|
||||
run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- name: Test Optimization Helpers
|
||||
run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||
#- name: Test Optimization Helpers
|
||||
# run: DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||
#- name: Test Action Space
|
||||
# run: DEBUG=1 GPU=1 python3 extra/optimization/get_action_space.py
|
||||
- name: Test Beam Search
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad.helpers import getenv, get_single_element
|
||||
from tinygrad.dtype import _to_np_dtype
|
||||
from tinygrad.codegen.opt.kernel import OptOps
|
||||
from tinygrad.codegen.opt import OptOps
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# stuff needed to unpack a kernel
|
||||
from tinygrad import Variable
|
||||
from tinygrad.codegen.opt.kernel import Opt, OptOps
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
@@ -2,7 +2,6 @@ import itertools
|
||||
from enum import Enum, auto
|
||||
from collections import defaultdict
|
||||
from typing import List, Tuple, DefaultDict
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
from tinygrad.helpers import prod, tqdm
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -147,6 +146,7 @@ def test_rebuild_bufferop_st(ast:UOp):
|
||||
for src in ast.src: test_rebuild_bufferop_st(src)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
ast_strs = load_worlds(False, False, True)[:2000]
|
||||
for ast_str in tqdm(ast_strs):
|
||||
test_rebuild_bufferop_st(ast_str_to_ast(ast_str))
|
||||
|
||||
2
test/external/external_test_train_gpt2.py
vendored
2
test/external/external_test_train_gpt2.py
vendored
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.codegen.opt.search import Opt, OptOps
|
||||
from .search import Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -22,7 +22,7 @@ if os.getenv("VALIDATE_HCQ", 0) != 0:
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.kernel import Opt, OptOps
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.codegen.opt.search import get_kernel_actions, bufs_from_lin
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
|
||||
|
||||
@@ -12,7 +12,7 @@ try:
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.codegen.opt.kernel import Opt
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.device import Device
|
||||
except ImportError as e:
|
||||
|
||||
@@ -3,12 +3,9 @@ from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, NOLOCALS, TC_OPT, TC_SELECT, USE_TC, AMX
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.uop.ops import Ops, resolve, AxisType
|
||||
|
||||
# both versions
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
|
||||
def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
def hand_coded_optimizations(k:Scheduler) -> list[Opt]:
|
||||
# first try the tensor cores
|
||||
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
|
||||
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
|
||||
@@ -38,15 +35,7 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
pass
|
||||
if good_tc_opt:
|
||||
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
|
||||
if isinstance(k, Kernel) and (tc_opts:=tk.tensor_core_opts) is not None and not AMX:
|
||||
# hand-coded TC opts
|
||||
for tc_dim in [tc_dim for tc_dim in [1,0] if tc_opts.axes_exist[tc_dim]]: # attempt to upcast M and N
|
||||
szs = [sz for sz in [5,4,3,2] if tk.full_shape[tc_opts.axes[tc_dim]] % sz == 0]
|
||||
if szs: tk.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[tc_dim], szs[0]))
|
||||
|
||||
if tc_opts.axes_exist[0] and (szs := [sz for sz in [4,2] if tk.full_shape[tc_opts.axes[0]] % sz == 0]): # attempt to local N
|
||||
tk.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], szs[0]))
|
||||
elif isinstance(k, Scheduler) and rngs is not None and not AMX:
|
||||
if rngs is not None and not AMX:
|
||||
for tc_dim in [1,0]: # attempt to upcast M and N
|
||||
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
|
||||
if szs:
|
||||
@@ -64,32 +53,17 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
|
||||
(mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
|
||||
if isinstance(k, Kernel):
|
||||
st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
|
||||
strides0, strides1 = st0.real_strides(), st1.real_strides()
|
||||
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
|
||||
if strides0[first_reduce:=(k.axes_of(AxisType.REDUCE)[0])] == 1 and \
|
||||
not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
|
||||
for global_idx in k.axes_of(AxisType.GLOBAL):
|
||||
if k.full_shape[first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3:
|
||||
print(f"MATVEC: {k.full_shape=} {first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return k.applied_opts
|
||||
else:
|
||||
idx0, idx1 = mulop.src[0].src[0].src[1], mulop.src[1].src[0].src[1]
|
||||
first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
|
||||
if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
|
||||
for global_idx in k.axes_of(AxisType.GLOBAL):
|
||||
if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3:
|
||||
print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return k.applied_opts
|
||||
idx0, idx1 = mulop.src[0].src[0].src[1], mulop.src[1].src[0].src[1]
|
||||
first_reduce_rng = k.ranges_of(AxisType.REDUCE)[0]
|
||||
if any(u is first_reduce_rng for u in idx0.split_uop(Ops.ADD)) and all(r in idx1.ranges for r in idx0.ranges):
|
||||
for global_idx in k.axes_of(AxisType.GLOBAL):
|
||||
if first_reduce_rng.src[0].divides(MV_THREADS_PER_ROW) is not None and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
|
||||
if DEBUG >= 3:
|
||||
print(f"MATVEC: {k.full_shape=} {first_reduce_rng.render()} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
|
||||
if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
|
||||
if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
|
||||
if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
|
||||
return k.applied_opts
|
||||
|
||||
# are we grouping? (requires local shape support)
|
||||
if resolve(prod(k.output_shape[i] for i in k.upcastable_dims) <= 2048, False):
|
||||
@@ -102,11 +76,8 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
# upcast float4 images
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
if isinstance(buf.src[0].dtype, ImageDType):
|
||||
if hasattr(k, "sts"):
|
||||
unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
|
||||
else:
|
||||
# part of real_strides
|
||||
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
|
||||
# part of real_strides
|
||||
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].split_uop(Ops.ADD) if c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
|
||||
if len(unit_stride_axes_mul_4):
|
||||
if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
|
||||
@@ -122,11 +93,9 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
to_upcast: list[int] = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in k.upcastable_dims:
|
||||
if isinstance(k, Kernel): is_masked = any(st.axis_is_masked(axis) for st in k.sts)
|
||||
else:
|
||||
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
||||
is_masked = any(len(st.src) > 2 and k.rngs[axis] in st.src[2].parents for st in k.bufs) or \
|
||||
any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
|
||||
# for Schedule, we check if the range is used in INDEX gates or WHERE gates
|
||||
is_masked = any(len(st.src) > 2 and k.rngs[axis] in st.src[2].parents for st in k.bufs) or \
|
||||
any(any(o is k.rngs[axis] for o in u.src[0].parents) for u in k.ast.parents if u.op is Ops.WHERE)
|
||||
if k.full_shape[axis] <= 7 and is_masked and prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
@@ -141,24 +110,16 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
for axis, upcast_amount in itertools.product(k.upcastable_dims, ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
|
||||
# if we haven't upcasted it, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis in upcasted_axis or k.full_shape[axis]%upcast_amount != 0: continue
|
||||
if isinstance(k, Kernel):
|
||||
# must have stride 0 on a view
|
||||
# must have all non stride 0 on what's upcasted before
|
||||
if any(st.views[-1].strides[axis] == 0 and \
|
||||
all(x != 0 for t,x in zip(k.axis_types, st.real_strides()) if t in (AxisType.UPCAST, AxisType.UNROLL)) for st in k.sts):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
|
||||
sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
|
||||
else:
|
||||
rng = k.rngs[axis]
|
||||
if any(rng not in b.src[1].parents and all(r2 in b.src[1].parents for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
|
||||
num_strides, sum_strides = 0, 0
|
||||
for b in k.bufs:
|
||||
if rng in b.src[1].parents: num_strides += 1
|
||||
for c in b.src[1].split_uop(Ops.ADD):
|
||||
if c is rng: sum_strides += 1
|
||||
if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
|
||||
if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
|
||||
xb_choices.append((num_strides, sum_strides, axis, upcast_amount))
|
||||
rng = k.rngs[axis]
|
||||
if any(rng not in b.src[1].parents and all(r2 in b.src[1].parents for r2 in k.ranges_of(AxisType.UPCAST, AxisType.UNROLL)) for b in k.bufs):
|
||||
num_strides, sum_strides = 0, 0
|
||||
for b in k.bufs:
|
||||
if rng in b.src[1].parents: num_strides += 1
|
||||
for c in b.src[1].split_uop(Ops.ADD):
|
||||
if c is rng: sum_strides += 1
|
||||
if c.op is Ops.MUL and c.src[0] is rng and c.src[1].op is Ops.CONST: sum_strides += c.src[1].arg
|
||||
if c.op is Ops.MUL and c.src[1] is rng and c.src[0].op is Ops.CONST: sum_strides += c.src[0].arg
|
||||
xb_choices.append((num_strides, sum_strides, axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
if DEBUG >= 4: print(f"more upcast axis : {xb_choices}")
|
||||
@@ -196,11 +157,8 @@ def hand_coded_optimizations(k:Kernel|Scheduler) -> list[Opt]:
|
||||
k.apply_opt(Opt(OptOps.NOLOCALS))
|
||||
else:
|
||||
# prioritize making expand axes local
|
||||
if isinstance(k, Kernel):
|
||||
local_axis_ranking = [(any(st.views[-1].strides[axis] == 0 for st in k.sts), axis) for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP)]
|
||||
else:
|
||||
local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].parents for b in k.bufs), axis) \
|
||||
for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
|
||||
local_axis_ranking = [(any(k.rngs[axis] not in b.src[1].parents for b in k.bufs), axis) \
|
||||
for axis in k.axes_of(AxisType.GLOBAL, AxisType.LOOP) if k.rngs[axis].src[0].op is Ops.CONST]
|
||||
to_local: list[tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
|
||||
@@ -1,435 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import itertools, functools, math
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final, Callable, Sequence
|
||||
from tinygrad.codegen.opt import OptOps, Opt, KernelOptError, check, axis_letters, axis_colors
|
||||
from tinygrad.uop.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, AxisType
|
||||
from tinygrad.uop.spec import type_verify, ast_spec
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_left_through_load
|
||||
|
||||
@dataclass
|
||||
class TensorCoreOptions:
|
||||
axes: tuple[int, ...] # the location of the original N and M axes if still in the shape
|
||||
axes_exist: tuple[bool, ...] # true if the original N and M axes are still in the shape
|
||||
axis_pads: tuple[tuple[int, int], ...]
|
||||
def fix_axes(self, removed_axis:int): # adjust the TC axes if necessary when a dimension is removed
|
||||
axes, axes_exist = list(self.axes), list(self.axes_exist)
|
||||
for tc_dim in [i for i in range(2) if axes_exist[i]]:
|
||||
if removed_axis < axes[tc_dim]: axes[tc_dim] -= 1
|
||||
elif removed_axis == axes[tc_dim]: axes_exist[tc_dim] = False
|
||||
self.axes, self.axes_exist = tuple(axes), tuple(axes_exist)
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:UOp, opts:Renderer|None=None):
|
||||
assert ast.op is Ops.SINK, ast.op
|
||||
self.ast = ast
|
||||
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
# verify AST matches the spec
|
||||
if __debug__: type_verify(list(self.ast.toposort()), ast_spec)
|
||||
|
||||
self.vars: list[Variable] = self.ast.variables()
|
||||
# NOTE: this requires a specific order with the [::-1], this is likely a bug
|
||||
self.bufs: list[UOp] = [x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.st is not None][::-1]
|
||||
|
||||
# create new shapetrackers inside this kernel, we will permute them
|
||||
self.sts: list[ShapeTracker] = [x.st_arg for x in self.bufs]
|
||||
|
||||
# add the shapetrackers for each reduce
|
||||
# we use this to track which axes are reduced in each reduce
|
||||
self.reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE_AXIS]
|
||||
for x in self.reduceops:
|
||||
self.sts.append(unwrap(x.st))
|
||||
self.sts.append(unwrap(x.src[0].st))
|
||||
|
||||
# add a shapetracker to the end to track the full shape, with 0 strides so it can merge
|
||||
full_shape = ast.full_shape
|
||||
self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
|
||||
|
||||
# parameters for optimization
|
||||
self.tensor_core: TensorCore|None = None
|
||||
self.tensor_core_opts: TensorCoreOptions|None = None
|
||||
self.use_tensor_cores: int = 0
|
||||
self.applied_opts: list[Opt] = []
|
||||
self.dont_use_locals = False
|
||||
self.finalized: bool = False
|
||||
|
||||
# group simplifies
|
||||
self.simplify_ones()
|
||||
self.simplify_merge_adjacent()
|
||||
|
||||
# axis types
|
||||
global_loops = AxisType.GLOBAL if self.opts.has_local else AxisType.LOOP
|
||||
self.axis_types: list[AxisType] = [AxisType.REDUCE if resolve(x!=y) else global_loops for x,y in zip(self.output_shape, self.full_shape)]
|
||||
|
||||
# confirm all reduce axes are at the end
|
||||
if (final_reduces := [x for x in self.axis_types if x == AxisType.REDUCE]) and final_reduces != self.axis_types[-len(final_reduces):]:
|
||||
raise RuntimeError(f"reduces are not at the end of the shape {self.full_shape} -> {self.output_shape}")
|
||||
|
||||
def copy(self):
|
||||
ret = type(self).__new__(type(self))
|
||||
|
||||
# base linearizer params
|
||||
ret.opts, ret.ast = self.opts, self.ast
|
||||
|
||||
# things downstream of the AST
|
||||
ret.reduceops, ret.vars, ret.bufs = self.reduceops, self.vars, self.bufs
|
||||
ret.sts = self.sts[:]
|
||||
ret.axis_types = self.axis_types[:]
|
||||
|
||||
# parameters for optimizations
|
||||
ret.applied_opts, ret.dont_use_locals = self.applied_opts[:], self.dont_use_locals
|
||||
ret.tensor_core, ret.tensor_core_opts, ret.use_tensor_cores = self.tensor_core, self.tensor_core_opts, self.use_tensor_cores
|
||||
ret.finalized = self.finalized
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
|
||||
@property
|
||||
def full_shape(self) -> tuple[sint, ...]: return self.sts[-1].shape
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
|
||||
@property
|
||||
def shape_len(self) -> int: return len(self.full_shape)
|
||||
|
||||
def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in argfix(axis_type)]
|
||||
@property
|
||||
def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
|
||||
@property
|
||||
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
|
||||
|
||||
# heuristic helpers
|
||||
@property
|
||||
def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
|
||||
if isinstance(s:=self.full_shape[i], int) and s > 1]
|
||||
@property
|
||||
def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
|
||||
if isinstance(s:=self.full_shape[i], int) and s > 1]
|
||||
|
||||
# ******************** colors and names ********************
|
||||
|
||||
def colors(self) -> list[str]:
|
||||
assert len(self.axis_types) == self.shape_len, "colors size mismatch"
|
||||
return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
|
||||
|
||||
def colored_shape(self, pad:int|None=None, dense=False) -> str:
|
||||
shape_strs = [(s if dense else f"{s:4d}") if isinstance(s, int) else s.render() for s in self.full_shape]
|
||||
ret = ' '.join(colored(s, color) for s,color in zip(shape_strs, self.colors()))
|
||||
if pad: ret += ' '*(pad-ansilen(ret))
|
||||
return ret
|
||||
|
||||
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
|
||||
@functools.cached_property
|
||||
def name(self) -> str:
|
||||
# kernel name (before late upcast)
|
||||
kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort()) else "E")
|
||||
suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
|
||||
|
||||
# name the function something unique
|
||||
Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
|
||||
num = f"n{Kernel.kernel_cnt[function_name]-1}" if Kernel.kernel_cnt[function_name] > 1 else ""
|
||||
return name + colored(num, 'BLACK')
|
||||
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
# apply reshape and permute to all shapetrackers
|
||||
def reshape(self, new_shape_fxn:Callable[[tuple[sint, ...]], Sequence[sint]]):
|
||||
self.sts = [st.reshape(tuple(new_shape_fxn(st.shape))) for st in self.sts]
|
||||
def permute(self, new_axes:Sequence[int]): self.sts = [st.permute(tuple(new_axes)) for st in self.sts]
|
||||
|
||||
# axis : the axis to pull from
|
||||
# amount : the amount to take
|
||||
# top : if you want to pull that amount from the top
|
||||
# insert_at : place to insert the new stuff
|
||||
def shift_to(self, axis:int, amount:int, new_type:AxisType, top:bool=False, insert_at:int|None=None) -> int:
|
||||
if insert_at is None: insert_at = self.shape_len
|
||||
self.axis_types.insert(insert_at, new_type)
|
||||
move_axis = axis if top else axis+1
|
||||
if move_axis < insert_at: insert_at += 1
|
||||
def new_shape_fxn(x): return x[0:axis] + (((amount,x[axis]//amount) if top else (x[axis]//amount,amount)) if x[axis] > 1 else (1,1)) + x[axis+1:]
|
||||
new_axes = [i for i in range(insert_at) if i != move_axis]+[move_axis]+[i for i in range(insert_at, self.shape_len+1) if i != move_axis]
|
||||
self.reshape(new_shape_fxn)
|
||||
self.permute(new_axes)
|
||||
return insert_at
|
||||
|
||||
# ******************** complex simplifiers ********************
|
||||
|
||||
def simplify_ones(self) -> bool:
|
||||
# remove places where the shape is all ones
|
||||
if any(all_ones:=[s==1 for s in self.full_shape]):
|
||||
if hasattr(self, 'axis_types'):
|
||||
self.axis_types = [x for i,x in enumerate(self.axis_types) if not all_ones[i]]
|
||||
self.reshape(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]])
|
||||
return True
|
||||
return False
|
||||
|
||||
def simplify_merge_adjacent(self):
|
||||
assert not hasattr(self, 'axis_types'), "don't call this after init"
|
||||
if self.shape_len == 0: return
|
||||
shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts]
|
||||
# NOTE: we can't use self.first_reduce yet
|
||||
first_reduce = [resolve(x!=y) for x,y in zip(self.output_shape+(0,), self.full_shape+(1,))].index(True)
|
||||
|
||||
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
|
||||
# TODO: remove membufs
|
||||
membufs = dedup([x.src[0].base for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
|
||||
if isinstance(membufs[0].base.dtype, ImageDType):
|
||||
base_shape = membufs[0].base.dtype.shape
|
||||
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
|
||||
special_strides: tuple[sint, ...] = tuple()
|
||||
for i,g in enumerate(shape_idx_groups):
|
||||
shape_piece = tuple(self.output_shape[x] for x in g)
|
||||
assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
|
||||
special_strides += strides_for_shape(shape_piece)
|
||||
# adding the fake image shape
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi _merge_dims
|
||||
# NOTE: this does not always preserve the reduce dimension
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
# TODO: how does this work with multi-reduce?
|
||||
rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
|
||||
for i in range(1, len(shapes[0])):
|
||||
can_merge = []
|
||||
for s,st,ret in zip(shapes, strides, rets):
|
||||
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
|
||||
si, sti, last_st = s[i], st[i], ret[-1][1]
|
||||
can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
|
||||
# more can merge than this
|
||||
mergeable = all(can_merge) and i != first_reduce
|
||||
for j,(s,st) in enumerate(zip(shapes, strides)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
|
||||
else: rets[j].append((s[i], st[i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
# ******************** apply optimizations ********************
|
||||
|
||||
def real_axis(self, op:OptOps, axis:int|None):
|
||||
try:
|
||||
if axis is None: return -1
|
||||
if op is OptOps.UNROLL: return self.unrollable_dims[axis]
|
||||
if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
|
||||
check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
|
||||
return axis
|
||||
except IndexError as e: raise KernelOptError from e
|
||||
|
||||
def apply_opt(self, opt:Opt, append_opt:bool=True) -> int|None:
|
||||
if self.finalized: raise RuntimeError("can't optimize Kernel after it's finalized")
|
||||
if self.dont_use_locals: check(opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}, "not using locals")
|
||||
|
||||
if opt.op is OptOps.TC:
|
||||
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
|
||||
check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
|
||||
check(opt.axis is not None, "tensor core opts must have an axis")
|
||||
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
|
||||
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
|
||||
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
|
||||
check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
|
||||
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
|
||||
self.applied_opts.append(opt)
|
||||
return None
|
||||
|
||||
axis = self.real_axis(opt.op, opt.axis)
|
||||
|
||||
if opt.op is OptOps.SWAP: amt = self.real_axis(opt.op, cast(int, opt.arg)) # arg is an axis in the SWAPs
|
||||
elif opt.arg is not None:
|
||||
check(isinstance(opt.arg, int), "arg should be int")
|
||||
amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
|
||||
check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
|
||||
if opt.op is not OptOps.PADTO:
|
||||
# we check both the full_shape and each shape
|
||||
check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
|
||||
for st in self.sts: check(st.shape[axis] == 1 or st.shape[axis] % amt == 0, f"no longer valid shift {st.shape[axis]=}, {amt=}")
|
||||
else: amt = -1
|
||||
|
||||
if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
|
||||
(self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
|
||||
acc_sz = self.reduceop.dtype.itemsize
|
||||
upcast_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST)])
|
||||
local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.LOCAL)])
|
||||
smem_sz = amt*acc_sz*upcast_sz*local_sz
|
||||
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
|
||||
|
||||
new_axis = None
|
||||
if opt.op is OptOps.LOCAL: # cyan
|
||||
# NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
|
||||
# it's disabled for now since it makes BEAM slow for little gain
|
||||
check(self.opts.has_local, "target does not support local")
|
||||
check(self.axis_types[axis] is AxisType.GLOBAL, "local is for globals")
|
||||
new_axis = self.shift_to(axis, amt, AxisType.LOCAL, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL))+1)
|
||||
elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
|
||||
check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
|
||||
check(self.axis_types[axis] is AxisType.REDUCE, "must be reduce axis to group")
|
||||
check(not self.tensor_core, "can't group with tensor cores")
|
||||
check(len(reduce_axes:=[i for r in self.reduceops for i in r.axis_arg]) == len(set(reduce_axes)), "can't group with parallel reduces")
|
||||
new_axis = self.shift_to(axis, amt, AxisType.GROUP_REDUCE, top=(opt.op is OptOps.GROUPTOP), insert_at=min(self.axes_of(AxisType.REDUCE)))
|
||||
elif opt.op is OptOps.UNROLL: # purple
|
||||
check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "can't upcasted already upcasted")
|
||||
check(amt <= 32, "don't unroll more than 32")
|
||||
new_axis = self.shift_to(axis, amt, AxisType.UNROLL, insert_at=None)
|
||||
elif opt.op is OptOps.UPCAST: # yellow
|
||||
check(axis in self.upcastable_dims, f"{axis=} not in {self.upcastable_dims=}")
|
||||
# NOTE: assume the first get_local_axes() LOCAL are for TC
|
||||
check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals")
|
||||
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
|
||||
new_axis = self.shift_to(axis, amt, AxisType.UPCAST,
|
||||
insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1)
|
||||
elif opt.op is OptOps.NOLOCALS:
|
||||
check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
|
||||
check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals")
|
||||
self.dont_use_locals = True
|
||||
elif opt.op is OptOps.SWAP:
|
||||
check(axis < amt, f"swap is only for axis < amt, getting {amt=}, {axis=}")
|
||||
check(self.axis_types[axis]==self.axis_types[amt]==AxisType.GLOBAL, f"swap is for globals {self.axis_types[axis]=}, {self.axis_types[amt]=}")
|
||||
permute = list(range(self.shape_len))
|
||||
permute[axis], permute[amt] = permute[amt], permute[axis]
|
||||
self.permute(tuple(permute))
|
||||
elif opt.op is OptOps.PADTO:
|
||||
check(not self.vars, "does not work with symbolic shape")
|
||||
check(self.axis_types[axis] not in (AxisType.UPCAST, AxisType.UNROLL), "cannot pad upcasted")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and self.axis_types[axis] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
|
||||
check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
|
||||
padded = False
|
||||
for i,st in enumerate(self.sts):
|
||||
if (s:=st.shape[axis]) == 1: continue # reduced
|
||||
check(s > amt//4, f"pad adds more than quadruple the work {st.shape[axis]=} > {amt//4=}")
|
||||
if (ru := round_up(cast(int, s), amt) - s):
|
||||
# pad right seems to be faster
|
||||
self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1))
|
||||
padded = True
|
||||
check(padded, "nothing was padded")
|
||||
|
||||
if append_opt: self.applied_opts.append(opt)
|
||||
if self.simplify_ones() and self.tensor_core_opts:
|
||||
self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
|
||||
return new_axis
|
||||
|
||||
def apply_opts(self, opts:Sequence[Opt]) -> Kernel:
|
||||
for opt in opts: self.apply_opt(opt)
|
||||
return self
|
||||
|
||||
# **** kernel outputs, mostly tensor cores ****
|
||||
|
||||
def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> TensorCoreOptions|None:
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
|
||||
|
||||
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
|
||||
if mul_op.op is not Ops.MUL: return None
|
||||
|
||||
def buf_index(src:UOp) -> int|None:
|
||||
# TODO: apply tc even if the sources are not from LOAD
|
||||
if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
|
||||
try:
|
||||
if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
|
||||
except ValueError: return None
|
||||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
|
||||
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i in self.upcastable_dims if buf0_strides[i] == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i in self.upcastable_dims if buf1_strides[i] == 0]
|
||||
if not (axis_buf0 and axis_buf1 and (len(self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (opt_level >= 1))): return None
|
||||
|
||||
axis_choices = list(itertools.product(axis_buf0, axis_buf1, self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)))
|
||||
if not (axis < len(axis_choices)): return None
|
||||
|
||||
s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
|
||||
axis_pads = tuple((x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if resolve(self.full_shape[x]%tc.dims[i] != 0))
|
||||
if axis_pads and (opt_level < 2): return None
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads)
|
||||
|
||||
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool:
|
||||
if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD:
|
||||
tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
|
||||
for tc in tensor_cores:
|
||||
tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
|
||||
if tensor_core_opts[0] is None: continue
|
||||
# can only fuse reduces with the same tc options
|
||||
assert all_same(tensor_core_opts)
|
||||
self.tensor_core_opts = tc_opts = tensor_core_opts[0]
|
||||
|
||||
# attempt to pad the tensor axes that require it
|
||||
try:
|
||||
for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
||||
except KernelOptError: continue
|
||||
# tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
|
||||
for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
|
||||
for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, 0, amt), append_opt=False) # TODO: this should be the reduce, not 0
|
||||
self.tensor_core = tc
|
||||
self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
|
||||
return True
|
||||
return False
|
||||
|
||||
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
|
||||
def shape_str(self) -> list[str]:
|
||||
ret: list[str] = []
|
||||
cnt: dict[AxisType, int] = {}
|
||||
for x in self.axis_types:
|
||||
cnt[x] = (cnt[x] + 1) if x in cnt else 0
|
||||
ret.append(f"{axis_letters[x]}{cnt[x]}")
|
||||
return ret
|
||||
def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])
|
||||
|
||||
def get_optimized_ast(self, name_override:str|None=None) -> UOp:
|
||||
@functools.cache
|
||||
def fixup_ast(op:UOp) -> UOp:
|
||||
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
|
||||
if op.op in GroupOp.Buffer and op in self.bufs:
|
||||
st = self.sts[self.bufs.index(op)]
|
||||
# replace the VIEW source
|
||||
return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
|
||||
if op.op is Ops.SINK:
|
||||
# NOTE: should group_for_reduces be added to the local_dims?
|
||||
# TODO: arg.name should be able to be None
|
||||
kernel_name = ret.arg.name if ret.arg is not None and ret.arg.name != "test" else self.name if name_override is None else name_override
|
||||
return ret.replace(arg=KernelInfo(kernel_name, tuple(self.axis_types), self.dont_use_locals, tuple(self.applied_opts)))
|
||||
if op.op is Ops.REDUCE_AXIS:
|
||||
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
|
||||
changed = tuple(i for i in range(self.shape_len) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
|
||||
axes = tuple(i for i in self.axes_of(AxisType.REDUCE, AxisType.GROUP_REDUCE, AxisType.UNROLL) if i in changed)
|
||||
if (tc := self.tensor_core) and self.use_tensor_cores == 1:
|
||||
# get reduce/upcast axes for the tensor cores
|
||||
tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
|
||||
base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
|
||||
tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
|
||||
|
||||
# permute the srcs
|
||||
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
|
||||
for i, (src, permaxis) in enumerate(zip(srcs, tc.permutes_for_shape_str(self.shape_str()))):
|
||||
src_st = (src if src.op is Ops.LOAD else src.src[0]).st_arg
|
||||
srcs[i] = src.view(ShapeTracker.from_shape(src_st.shape).permute(permaxis))
|
||||
|
||||
# construct the op
|
||||
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
|
||||
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
|
||||
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
|
||||
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
|
||||
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
|
||||
|
||||
# preserve any other reduce
|
||||
return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
|
||||
|
||||
ret = ret.replace(arg = (op.arg[0], axes))
|
||||
return ret
|
||||
self.finalized = True
|
||||
fixed_ast = fixup_ast(self.ast)
|
||||
del fixup_ast
|
||||
return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")
|
||||
@@ -12,8 +12,6 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
|
||||
# both versions
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
|
||||
@@ -63,7 +61,7 @@ def timeout_handler(signum, frame):
|
||||
if DEBUG >= 2: print("*** BEAM COMPILE TIMEOUT")
|
||||
raise TimeoutException()
|
||||
|
||||
def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
|
||||
def _try_compile_linearized_w_idx(x:tuple[int,Scheduler], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
|
||||
if hasattr(signal, "alarm"):
|
||||
signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
|
||||
# set timeout
|
||||
@@ -98,7 +96,7 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
|
||||
|
||||
# get (scrap) buffers for timing the linearizer
|
||||
# NOTE: there's also bufs_from_ast in postrange
|
||||
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
|
||||
def bufs_from_lin(lin:Scheduler, allocate:bool=True) -> list[Buffer]:
|
||||
bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
|
||||
for x in lin.bufs:
|
||||
if x.src[0].base.op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].base.arg].append(x)
|
||||
@@ -115,7 +113,7 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
|
||||
return cast(list[Buffer], rawbufs)
|
||||
|
||||
# get dictionary of all possible actions
|
||||
def get_kernel_actions(lin:Kernel|Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Kernel|Scheduler]:
|
||||
def get_kernel_actions(lin:Scheduler, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Scheduler]:
|
||||
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
|
||||
kernel_actions = (actions if candidates is None else candidates).copy()
|
||||
|
||||
@@ -139,7 +137,7 @@ def get_kernel_actions(lin:Kernel|Scheduler, include_0=True, candidates:list[Opt
|
||||
return acted_lins
|
||||
|
||||
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
|
||||
def beam_search(lin:Kernel|Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value):
|
||||
def beam_search(lin:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value):
|
||||
global beam_pool
|
||||
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
|
||||
@@ -147,7 +145,7 @@ def beam_search(lin:Kernel|Scheduler, rawbufs:list[Buffer], amt:int, allow_test_
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
return ret
|
||||
|
||||
beam: list[tuple[Kernel|Scheduler, float]] = [(lin, float("inf"))]
|
||||
beam: list[tuple[Scheduler, float]] = [(lin, float("inf"))]
|
||||
seen_libs = set()
|
||||
|
||||
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
|
||||
@@ -168,8 +166,8 @@ def beam_search(lin:Kernel|Scheduler, rawbufs:list[Buffer], amt:int, allow_test_
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[lin.opts.device]
|
||||
while not exiting:
|
||||
acted_lins: list[Kernel|Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
|
||||
timed_lins: list[tuple[Kernel|Scheduler, float]] = []
|
||||
acted_lins: list[Scheduler] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
|
||||
timed_lins: list[tuple[Scheduler, float]] = []
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
||||
least_compute_ops = math.inf
|
||||
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.engine.schedule import ScheduleItem
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.codegen.opt.kernel import Opt
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, printable, GroupOp,
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.codegen.opt.kernel import axis_colors
|
||||
from tinygrad.codegen.opt import axis_colors
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
|
||||
Reference in New Issue
Block a user