* beam as uop

* x
This commit is contained in:
nimlgen
2026-04-09 19:13:03 +03:00
committed by GitHub
parent 0ff30b003d
commit 057dc173ab
7 changed files with 25 additions and 17 deletions

View File

@@ -4,7 +4,7 @@ import numpy as np
from tinygrad.helpers import BEAM, Timing, CI, prod
from tinygrad import Variable, Device, Tensor
from tinygrad.nn import Conv2d
from tinygrad.uop.ops import AxisType
from tinygrad.uop.ops import AxisType, Ops
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.opt.postrange import Scheduler
from tinygrad.codegen.opt.search import get_kernel_actions
@@ -85,6 +85,7 @@ class TestBeamSearch(unittest.TestCase):
size = max(tc.dims[0], tc.dims[1]) * 8
a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
if ast.op is Ops.BEAM: ast = ast.src[0]
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
@@ -95,6 +96,7 @@ class TestBeamSearch(unittest.TestCase):
def test_max_up(self):
a = Tensor.rand(16, 16)
ast = a.schedule()[-1].ast
if ast.op is Ops.BEAM: ast = ast.src[0]
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
for max_up in (2, 4):
actions = get_kernel_actions(s, include_0=False, max_up=max_up)

View File

@@ -21,7 +21,7 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.renderer.amd.elf import do_assemble_amd
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, beam:int=0) -> UOp:
if ren is None: ren = Renderer(Target())
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
@@ -46,7 +46,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = graph_rewrite(sink, pm_flatten_range+pm_simplify_ranges, ctx={}, name="simplify ranges")
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren)
sink = apply_opts(sink, ren, beam=beam)
# ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
@@ -164,14 +164,15 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
"""
if ast.op is Ops.PROGRAM: prg = ast
elif ast.op is Ops.SINK:
elif ast.op is Ops.SINK or ast.op is Ops.BEAM:
beam, ast = (ast.arg, ast.src[0]) if ast.op is Ops.BEAM else (0, ast)
# rewrite to prg
assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to get_program"
if opts is not None:
# TODO: should this be here?
assert ast.arg.opts_to_apply is None, "can't apply opts if there's already opts to apply"
ast = ast.replace(arg=replace(ast.arg, opts_to_apply=tuple(opts)))
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None, beam=beam)
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)))
else:
raise RuntimeError(f"can't call get_program on {ast.op}")

View File

@@ -6,7 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_r
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.helpers import colored, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.helpers import ALLOW_TF32, count, Context
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
from tinygrad.codegen.simplify import pm_flatten_range
@@ -334,18 +334,18 @@ def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg)
return [Buffer(dname, x.ptrdtype.size, x.dtype.base) for x in glbls]
def apply_opts(ast:UOp, ren:Renderer) -> UOp:
def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
if ast.tag is not None: return ast
k = Scheduler(ast, ren)
k.convert_loop_to_global()
if ast.arg is not None and ast.arg.opts_to_apply is not None:
for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
elif BEAM >= 1:
elif beam >= 1:
from tinygrad.codegen.opt.search import beam_search
rawbufs = bufs_from_ast(ast, ren.target.device)
# beam search may open devices
with Context(ALLOW_DEVICE_USAGE=1):
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
k = beam_search(k, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet

View File

@@ -1,8 +1,8 @@
from typing import cast, Callable
import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, unwrap
from tinygrad.helpers import EMULATED_DTYPES
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
from tinygrad.device import Device, Buffer
@@ -130,7 +130,7 @@ class EncDec(Runner):
method_cache: dict[tuple[str, type, bytes, tuple, bool], CompiledRunner] = {}
def get_runner(device:str, ast:UOp) -> CompiledRunner:
# TODO: this should be all context relevant to rendering
context = (BEAM.value, NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
context = (NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
ckey = (device, type(Device[device].compiler), ast.key, context, False)
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
@@ -145,7 +145,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat((Ops.SINK, Ops.PROGRAM, Ops.BEAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
@@ -198,7 +198,8 @@ capturing: list = [] # put classes with an add_linear method in here
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
while len(schedule):
ei = schedule.pop(0).lower()
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
sink = ei.ast.src[0] if ei.ast.op is Ops.BEAM else ei.ast
if VALIDATE_WITH_CPU and sink.op is Ops.SINK:
# copy in allocated buffers from the GPU
bufs = [b for b in ei.bufs if b is not None]
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
@@ -209,7 +210,7 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
ei.run(var_vals, do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
ExecItem(sink, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
import numpy as np
assert nb[0] is not None
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)

View File

@@ -4,7 +4,7 @@ from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
@@ -72,6 +72,8 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
base = buf_uops[1].buffer
assert isinstance(base, Buffer), "base can't be MultiBuffer"
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
# wrap SINK with BEAM UOp when beam search is enabled
if ast.op is Ops.SINK and BEAM >= 1: ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
metadata = si.arg.metadata
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":

View File

@@ -34,7 +34,7 @@ class Ops(FastEnum):
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
# GROUP is a NOOP that just merges things together
SINK = auto(); AFTER = auto(); GROUP = auto()
SINK = auto(); AFTER = auto(); GROUP = auto(); BEAM = auto()
# vector creation / item selection
GEP = auto(); VECTORIZE = auto()

View File

@@ -298,6 +298,8 @@ full_spec = PatternMatcher([
(UPat(Ops.DEFINE_VAR, dtype=dtypes.floats), lambda: True),
# allow any AFTER
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
# BEAM wraps a SINK for beam search
(UPat(Ops.BEAM, src=(UPat(Ops.SINK),)), lambda: True),
])+_tensor_spec+kernel_spec+program_spec+shared_spec
# ***** uop helpers *****