diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py
index 36aba141b6..b1c397f742 100644
--- a/extra/optimization/test_beam_search.py
+++ b/extra/optimization/test_beam_search.py
@@ -4,7 +4,7 @@ import numpy as np
 from tinygrad.helpers import BEAM, Timing, CI, prod
 from tinygrad import Variable, Device, Tensor
 from tinygrad.nn import Conv2d
-from tinygrad.uop.ops import AxisType
+from tinygrad.uop.ops import AxisType, Ops
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.codegen.opt.postrange import Scheduler
 from tinygrad.codegen.opt.search import get_kernel_actions
@@ -85,6 +85,7 @@ class TestBeamSearch(unittest.TestCase):
     size = max(tc.dims[0], tc.dims[1]) * 8
     a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
     ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
+    if ast.op is Ops.BEAM: ast = ast.src[0]
     s = Scheduler(ast, Device[Device.DEFAULT].renderer)
     s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
     up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
@@ -95,6 +96,7 @@ class TestBeamSearch(unittest.TestCase):
   def test_max_up(self):
     a = Tensor.rand(16, 16)
     ast = a.schedule()[-1].ast
+    if ast.op is Ops.BEAM: ast = ast.src[0]
     s = Scheduler(ast, Device[Device.DEFAULT].renderer)
     for max_up in (2, 4):
       actions = get_kernel_actions(s, include_0=False, max_up=max_up)
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index db3026fde5..6381c76b46 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -21,7 +21,7 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
 from tinygrad.renderer.amd.elf import do_assemble_amd
 
-def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
+def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, beam:int=0) -> UOp:
   if ren is None: ren = Renderer(Target())
   if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
@@ -46,7 +46,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
   sink = graph_rewrite(sink, pm_flatten_range+pm_simplify_ranges, ctx={}, name="simplify ranges")
 
   # do postrange optimization, BEAM or hand_coded_optimizations
-  sink = apply_opts(sink, ren)
+  sink = apply_opts(sink, ren, beam=beam)
 
   # ** expander (expand_rewrite) **
   sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
@@ -164,14 +164,15 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
   """
   if ast.op is Ops.PROGRAM: prg = ast
-  elif ast.op is Ops.SINK:
+  elif ast.op is Ops.SINK or ast.op is Ops.BEAM:
+    beam, ast = (ast.arg, ast.src[0]) if ast.op is Ops.BEAM else (0, ast)
     # rewrite to prg
     assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to get_program"
     if opts is not None:
       # TODO: should this be here?
       assert ast.arg.opts_to_apply is None, "can't apply opts if there's already opts to apply"
       ast = ast.replace(arg=replace(ast.arg, opts_to_apply=tuple(opts)))
-    full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
+    full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None, beam=beam)
     prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)))
   else: raise RuntimeError(f"can't call get_program on {ast.op}")
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index 0d21831d6f..dfffc27dbb 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_r
 from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
 from tinygrad.device import Buffer
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
+from tinygrad.helpers import colored, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
 from tinygrad.helpers import ALLOW_TF32, count, Context
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
 from tinygrad.codegen.simplify import pm_flatten_range
@@ -334,18 +334,18 @@ def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
   glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg)
   return [Buffer(dname, x.ptrdtype.size, x.dtype.base) for x in glbls]
 
-def apply_opts(ast:UOp, ren:Renderer) -> UOp:
+def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
   if ast.tag is not None: return ast
   k = Scheduler(ast, ren)
   k.convert_loop_to_global()
   if ast.arg is not None and ast.arg.opts_to_apply is not None:
     for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
-  elif BEAM >= 1:
+  elif beam >= 1:
     from tinygrad.codegen.opt.search import beam_search
     rawbufs = bufs_from_ast(ast, ren.target.device)
     # beam search may open devices
     with Context(ALLOW_DEVICE_USAGE=1):
-      k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
+      k = beam_search(k, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
   elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
     from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
     # NOTE: hand_coded_optimizations doesn't support multiblock opts yet
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index a2d387a0e3..06e0bb3329 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -1,8 +1,8 @@
 from typing import cast, Callable
 import time, pprint, random, itertools, math
 from dataclasses import dataclass, replace, field
-from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
-from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap
+from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, unwrap
 from tinygrad.helpers import EMULATED_DTYPES
 from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
 from tinygrad.device import Device, Buffer
@@ -130,7 +130,7 @@ class EncDec(Runner):
 method_cache: dict[tuple[str, type, bytes, tuple, bool], CompiledRunner] = {}
 def get_runner(device:str, ast:UOp) -> CompiledRunner:
   # TODO: this should be all context relevant to rendering
-  context = (BEAM.value, NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
+  context = (NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
   ckey = (device, type(Device[device].compiler), ast.key, context, False)
   if cret:=method_cache.get(ckey): return cret
   bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
@@ -145,7 +145,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
 
 # NOTE: ctx is the buffers
 si_lowerer = PatternMatcher([
-  (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
+  (UPat((Ops.SINK, Ops.PROGRAM, Ops.BEAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
   (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
   (UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
     if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
@@ -198,7 +198,8 @@ capturing: list = []  # put classes with an add_linear method in here
 def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
   while len(schedule):
     ei = schedule.pop(0).lower()
-    if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
+    sink = ei.ast.src[0] if ei.ast.op is Ops.BEAM else ei.ast
+    if VALIDATE_WITH_CPU and sink.op is Ops.SINK:
       # copy in allocated buffers from the GPU
       bufs = [b for b in ei.bufs if b is not None]
       nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
@@ -209,7 +210,7 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
       ei.run(var_vals, do_update_stats=do_update_stats)
 
       # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
-      with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
+      ExecItem(sink, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
       import numpy as np
       assert nb[0] is not None
       np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 4b553c7a77..cf9261f337 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -4,7 +4,7 @@ from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten
+from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM
 from tinygrad.engine.realize import ExecItem
 
 # **** schedule linearizer
@@ -72,6 +72,8 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
       base = buf_uops[1].buffer
       assert isinstance(base, Buffer), "base can't be MultiBuffer"
       buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
+    # wrap SINK with BEAM UOp when beam search is enabled
+    if ast.op is Ops.SINK and BEAM >= 1: ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)
     ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
     metadata = si.arg.metadata
     if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py
index a67a2f8e98..aaabfe3756 100644
--- a/tinygrad/uop/__init__.py
+++ b/tinygrad/uop/__init__.py
@@ -34,7 +34,7 @@ class Ops(FastEnum):
 
   # AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
   # GROUP is a NOOP that just merges things together
-  SINK = auto(); AFTER = auto(); GROUP = auto()
+  SINK = auto(); AFTER = auto(); GROUP = auto(); BEAM = auto()
 
   # vector creation / item selection
   GEP = auto(); VECTORIZE = auto()
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index ffde0dd57a..b9450f0f18 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -298,6 +298,8 @@ full_spec = PatternMatcher([
   (UPat(Ops.DEFINE_VAR, dtype=dtypes.floats), lambda: True),
   # allow any AFTER
   (UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
+  # BEAM wraps a SINK for beam search
+  (UPat(Ops.BEAM, src=(UPat(Ops.SINK),)), lambda: True),
 ])+_tensor_spec+kernel_spec+program_spec+shared_spec
 
 # ***** uop helpers *****