mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from tinygrad.helpers import BEAM, Timing, CI, prod
|
||||
from tinygrad import Variable, Device, Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.uop.ops import AxisType
|
||||
from tinygrad.uop.ops import AxisType, Ops
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
from tinygrad.codegen.opt.search import get_kernel_actions
|
||||
@@ -85,6 +85,7 @@ class TestBeamSearch(unittest.TestCase):
|
||||
size = max(tc.dims[0], tc.dims[1]) * 8
|
||||
a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
|
||||
ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
|
||||
if ast.op is Ops.BEAM: ast = ast.src[0]
|
||||
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
|
||||
s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
|
||||
up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
|
||||
@@ -95,6 +96,7 @@ class TestBeamSearch(unittest.TestCase):
|
||||
def test_max_up(self):
|
||||
a = Tensor.rand(16, 16)
|
||||
ast = a.schedule()[-1].ast
|
||||
if ast.op is Ops.BEAM: ast = ast.src[0]
|
||||
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
|
||||
for max_up in (2, 4):
|
||||
actions = get_kernel_actions(s, include_0=False, max_up=max_up)
|
||||
|
||||
@@ -21,7 +21,7 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.renderer.amd.elf import do_assemble_amd
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True, beam:int=0) -> UOp:
|
||||
if ren is None: ren = Renderer(Target())
|
||||
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Base AST")
|
||||
@@ -46,7 +46,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
sink = graph_rewrite(sink, pm_flatten_range+pm_simplify_ranges, ctx={}, name="simplify ranges")
|
||||
|
||||
# do postrange optimization, BEAM or hand_coded_optimizations
|
||||
sink = apply_opts(sink, ren)
|
||||
sink = apply_opts(sink, ren, beam=beam)
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
||||
@@ -164,14 +164,15 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
|
||||
"""
|
||||
|
||||
if ast.op is Ops.PROGRAM: prg = ast
|
||||
elif ast.op is Ops.SINK:
|
||||
elif ast.op is Ops.SINK or ast.op is Ops.BEAM:
|
||||
beam, ast = (ast.arg, ast.src[0]) if ast.op is Ops.BEAM else (0, ast)
|
||||
# rewrite to prg
|
||||
assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to get_program"
|
||||
if opts is not None:
|
||||
# TODO: should this be here?
|
||||
assert ast.arg.opts_to_apply is None, "can't apply opts if there's already opts to apply"
|
||||
ast = ast.replace(arg=replace(ast.arg, opts_to_apply=tuple(opts)))
|
||||
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
|
||||
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None, beam=beam)
|
||||
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)))
|
||||
else:
|
||||
raise RuntimeError(f"can't call get_program on {ast.op}")
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_r
|
||||
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.helpers import ALLOW_TF32, count, Context
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
@@ -334,18 +334,18 @@ def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
|
||||
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg)
|
||||
return [Buffer(dname, x.ptrdtype.size, x.dtype.base) for x in glbls]
|
||||
|
||||
def apply_opts(ast:UOp, ren:Renderer) -> UOp:
|
||||
def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
|
||||
if ast.tag is not None: return ast
|
||||
k = Scheduler(ast, ren)
|
||||
k.convert_loop_to_global()
|
||||
if ast.arg is not None and ast.arg.opts_to_apply is not None:
|
||||
for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
|
||||
elif BEAM >= 1:
|
||||
elif beam >= 1:
|
||||
from tinygrad.codegen.opt.search import beam_search
|
||||
rawbufs = bufs_from_ast(ast, ren.target.device)
|
||||
# beam search may open devices
|
||||
with Context(ALLOW_DEVICE_USAGE=1):
|
||||
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
k = beam_search(k, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import cast, Callable
|
||||
import time, pprint, random, itertools, math
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context, unwrap
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, unwrap
|
||||
from tinygrad.helpers import EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
@@ -130,7 +130,7 @@ class EncDec(Runner):
|
||||
method_cache: dict[tuple[str, type, bytes, tuple, bool], CompiledRunner] = {}
|
||||
def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
# TODO: this should be all context relevant to rendering
|
||||
context = (BEAM.value, NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
|
||||
context = (NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
|
||||
ckey = (device, type(Device[device].compiler), ast.key, context, False)
|
||||
if cret:=method_cache.get(ckey): return cret
|
||||
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
|
||||
@@ -145,7 +145,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
|
||||
# NOTE: ctx is the buffers
|
||||
si_lowerer = PatternMatcher([
|
||||
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
|
||||
(UPat((Ops.SINK, Ops.PROGRAM, Ops.BEAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
|
||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
|
||||
(UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
@@ -198,7 +198,8 @@ capturing: list = [] # put classes with an add_linear method in here
|
||||
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
while len(schedule):
|
||||
ei = schedule.pop(0).lower()
|
||||
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
|
||||
sink = ei.ast.src[0] if ei.ast.op is Ops.BEAM else ei.ast
|
||||
if VALIDATE_WITH_CPU and sink.op is Ops.SINK:
|
||||
# copy in allocated buffers from the GPU
|
||||
bufs = [b for b in ei.bufs if b is not None]
|
||||
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
|
||||
@@ -209,7 +210,7 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
|
||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
|
||||
ExecItem(sink, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
assert nb[0] is not None
|
||||
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
|
||||
# **** schedule linearizer
|
||||
@@ -72,6 +72,8 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
||||
base = buf_uops[1].buffer
|
||||
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
||||
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
# wrap SINK with BEAM UOp when beam search is enabled
|
||||
if ast.op is Ops.SINK and BEAM >= 1: ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)
|
||||
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
|
||||
metadata = si.arg.metadata
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
|
||||
|
||||
@@ -34,7 +34,7 @@ class Ops(FastEnum):
|
||||
|
||||
# AFTER passes src[0] through and promises in the toposort that any consumers of the AFTER run after src[1:]
|
||||
# GROUP is a NOOP that just merges things together
|
||||
SINK = auto(); AFTER = auto(); GROUP = auto()
|
||||
SINK = auto(); AFTER = auto(); GROUP = auto(); BEAM = auto()
|
||||
|
||||
# vector creation / item selection
|
||||
GEP = auto(); VECTORIZE = auto()
|
||||
|
||||
@@ -298,6 +298,8 @@ full_spec = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.floats), lambda: True),
|
||||
# allow any AFTER
|
||||
(UPat(Ops.AFTER, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
# BEAM wraps a SINK for beam search
|
||||
(UPat(Ops.BEAM, src=(UPat(Ops.SINK),)), lambda: True),
|
||||
])+_tensor_spec+kernel_spec+program_spec+shared_spec
|
||||
|
||||
# ***** uop helpers *****
|
||||
|
||||
Reference in New Issue
Block a user