From 77965a22e5781748d1771787ee3ab31f3b14a990 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 28 Apr 2026 22:51:04 +0300 Subject: [PATCH] local optimize as rewrite (#15953) * local optimize as rewrite * better * x * slighly rename * fix * ugh * remove * x * remove * not weak --- .github/workflows/benchmark.yml | 4 ++ examples/openpilot/compile3.py | 18 ++++-- extra/gemm/triton_nv_matmul.py | 4 +- test/backend/test_linearizer.py | 3 +- test/backend/test_renderer_failures.py | 4 +- test/backend/test_uops.py | 3 +- test/device/test_hcq.py | 3 +- test/external/external_benchmark_op_conv.py | 2 +- test/external/external_uop_gc.py | 3 +- test/opt/test_tensor_cores.py | 5 +- tinygrad/codegen/__init__.py | 18 +++--- tinygrad/codegen/opt/search.py | 4 +- tinygrad/engine/jit.py | 2 +- tinygrad/engine/realize.py | 64 ++++++++++++--------- tinygrad/runtime/graph/hcq.py | 2 +- tinygrad/uop/ops.py | 9 ++- 16 files changed, 82 insertions(+), 66 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 9277b16886..ffb67b7d5e 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -623,6 +623,8 @@ jobs: run: test/external/process_replay/reset.py - name: openpilot compile3 0.11.0 driving_vision run: BENCHMARK_LOG=openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx + - name: openpilot compile3 0.11.0 driving_vision (from pickle) + run: BENCHMARK_LOG=openpilot_0_11_0_vision_run_pickle RUN_PICKLE=1 PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM taskset -c 4-7 python3 examples/openpilot/compile3.py - name: IR3 openpilot compile3 0.11.0 driving_vision run: BENCHMARK_LOG=ir3_openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM:IR3 FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.11.0 driving_policy @@ -668,6 +670,8 @@ jobs: run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision PYTHONPATH="." GMMU=0 DEV=USB+AMD:LLVM ASSERT_MIN_STEP_TIME=50 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: openpilot load_pickle 0.10.1 driving_vision run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision_load_pickle PYTHONPATH="." GMMU=0 DEV=USB+AMD ASSERT_MIN_LOAD_TIME=15 python3 examples/openpilot/load_pickle.py + - name: openpilot run_pickle 0.10.1 driving_vision + run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision_run_pickle RUN_PICKLE=1 PYTHONPATH="." GMMU=0 DEV=USB+AMD ASSERT_MIN_STEP_TIME=50 python3 examples/openpilot/compile3.py testreddriverbenchmark: name: AM Benchmark diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index e0328704f3..c8a4502a8f 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -133,14 +133,20 @@ def bench(run, inputs): run(**inputs).numpy() if __name__ == "__main__": - onnx_file = fetch(OPENPILOT_MODEL) - inputs, outputs = compile(onnx_file) + if getenv("RUN_PICKLE"): + with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f) + inputs = {name: Tensor(Tensor.randn(*[int(s) for s in view.src[1].arg], dtype=dtype).numpy(), device=device) + for name, (view, _vars, dtype, device) in zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_input_info)} + test_vs_compile(pickle_loaded, inputs) + else: + onnx_file = fetch(OPENPILOT_MODEL) + inputs, outputs = compile(onnx_file) - with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f) + with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f) - test_vs_compile(pickle_loaded, inputs, outputs) - if getenv("SELFTEST"): - test_vs_onnx(inputs, outputs, onnx_file, 1e-4) + test_vs_compile(pickle_loaded, inputs, outputs) + if getenv("SELFTEST"): + test_vs_onnx(inputs, outputs, onnx_file, 1e-4) if getenv("BENCHMARK_LOG", ""): bench(pickle_loaded, inputs) diff --git a/extra/gemm/triton_nv_matmul.py b/extra/gemm/triton_nv_matmul.py index 678e82e7e4..60fe9d5c82 100644 --- a/extra/gemm/triton_nv_matmul.py +++ b/extra/gemm/triton_nv_matmul.py @@ -89,11 +89,11 @@ if __name__ == "__main__": # remove debug sections src = src.split("\t.file")[0] assert '.extern .shared' not in src - info = ProgramInfo(name="matmul_kernel", device=Device.DEFAULT, + info = ProgramInfo(name="matmul_kernel", global_size=(M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1), local_size=(32*compiled.metadata.num_warps, 1, 1)) sink = UOp.sink(arg=KernelInfo(name="matmul_kernel")) prg_uop = UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR), UOp(Ops.SOURCE, arg=src)), arg=info) - runner = CompiledRunner(prg_uop) + runner = CompiledRunner(prg_uop, Device.DEFAULT) all_bufs = [x.ensure_allocated() for x in bufs] prg_bufs = [all_bufs[i] for i in runner.p.globals] tflops = [] diff --git a/test/backend/test_linearizer.py b/test/backend/test_linearizer.py index 34f8b95a1b..9529460e75 100644 --- a/test/backend/test_linearizer.py +++ b/test/backend/test_linearizer.py @@ -1,6 +1,5 @@ import numpy as np import unittest -from dataclasses import replace from tinygrad.codegen.opt import Opt, OptOps from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType @@ -430,7 +429,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[] def get_prg(opts): ast = realized_ast if opts is None else replace_opts(realized_ast, list(opts)) - return CompiledRunner((pu:=to_program(ast, renderer=Device[Device.DEFAULT].renderer)).replace(arg=replace(pu.arg, device=device))) + return CompiledRunner(to_program(ast, renderer=Device[Device.DEFAULT].renderer), device) def check_opt(opts): prg = get_prg(opts=opts) diff --git a/test/backend/test_renderer_failures.py b/test/backend/test_renderer_failures.py index bf4a2f3dda..11abf6c603 100644 --- a/test/backend/test_renderer_failures.py +++ b/test/backend/test_renderer_failures.py @@ -19,9 +19,9 @@ def _test_uop_result(inputs:list[Tensor], prg:UOp, local_size=None): outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \ initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE] inbufs = [x.uop.base.buffer for x in inputs] - info = replace(prg.arg, device=Device.DEFAULT) + info = prg.arg if local_size is not None: info = replace(info, local_size=tuple(local_size)) - ei = CompiledRunner(prg.replace(arg=info)) + ei = CompiledRunner(prg.replace(arg=info), Device.DEFAULT) ei.exec(outbufs+inbufs) return [np.frombuffer(x.as_memoryview(), _to_np_dtype(x.dtype)) for x in outbufs] diff --git a/test/backend/test_uops.py b/test/backend/test_uops.py index 653256a187..40fd9cf6a1 100644 --- a/test/backend/test_uops.py +++ b/test/backend/test_uops.py @@ -13,11 +13,10 @@ from tinygrad.device import is_dtype_supported from tinygrad.codegen.opt import Opt, OptOps from tinygrad.renderer.ptx import PTXRenderer from test.helpers import to_uops_list -from dataclasses import replace def _uops_to_prg(uops_list): prg = to_program(UOp.sink(*uops_list, arg=KernelInfo()), Device[Device.DEFAULT].renderer) - return CompiledRunner(prg.replace(arg=replace(prg.arg, device=Device.DEFAULT))) + return CompiledRunner(prg, Device.DEFAULT) def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp: if op is Ops.CONST: uops.append(UOp.const(dtype, arg)) diff --git a/test/device/test_hcq.py b/test/device/test_hcq.py index b43df7869d..edba1d6ea2 100644 --- a/test/device/test_hcq.py +++ b/test/device/test_hcq.py @@ -166,7 +166,8 @@ class TestHCQ(unittest.TestCase): b = a + 1 si = b.schedule_linear().src[-1] - runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer)) + runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer), + Device.DEFAULT) zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated() zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated() diff --git a/test/external/external_benchmark_op_conv.py b/test/external/external_benchmark_op_conv.py index 00f9fb8b9c..edaa47448b 100644 --- a/test/external/external_benchmark_op_conv.py +++ b/test/external/external_benchmark_op_conv.py @@ -90,7 +90,7 @@ renderer = Device.default.renderer allocator = Device.default.allocator ps = to_program(ast, renderer) -cr = CompiledRunner(ps.replace(arg=replace(ps.arg, device=Device.DEFAULT))) +cr = CompiledRunner(ps, Device.DEFAULT) gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.PARAM]), key=lambda u: u.arg) # print(len(gs)) diff --git a/test/external/external_uop_gc.py b/test/external/external_uop_gc.py index 91b7e5ef0d..3077009ae8 100644 --- a/test/external/external_uop_gc.py +++ b/test/external/external_uop_gc.py @@ -2,7 +2,7 @@ import gc from tinygrad import Tensor, UOp, Device, nn from tinygrad.schedule import schedule_cache from tinygrad.engine.realize import method_cache -from tinygrad.codegen import to_program +from tinygrad.codegen import to_program, to_program_cache from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape from tinygrad.uop.divandmod import fold_divmod_general from test.test_tiny import TestTiny @@ -72,6 +72,7 @@ if __name__ == "__main__": # these caches will keep uops alive schedule_cache.clear() method_cache.clear() + to_program_cache.clear() apply_movement_op.cache_clear() _apply_reshape.cache_clear() fold_divmod_general.cache_clear() diff --git a/test/opt/test_tensor_cores.py b/test/opt/test_tensor_cores.py index fc7d3d1397..669afbbfd9 100644 --- a/test/opt/test_tensor_cores.py +++ b/test/opt/test_tensor_cores.py @@ -1,6 +1,5 @@ import numpy as np import unittest -from dataclasses import replace from tinygrad import Device, Tensor, dtypes from tinygrad.tensor import _to_np_dtype @@ -51,7 +50,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi pu = to_program(replace_opts(realized_ast, opts), Device[Device.DEFAULT].renderer) if use_tensor_cores == 1: assert len([uop for uop in pu.src[2].src if uop.op is Ops.WMMA]) > 0, "wmma not triggered" assert len([x for x in pu.src[0].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" - prg = CompiledRunner(pu.replace(arg=replace(pu.arg, device=Device.DEFAULT))) + prg = CompiledRunner(pu, Device.DEFAULT) prg.exec(bufs) if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = (1e-1, 2e-2) if dtype_out == dtypes.bfloat16 else (1e-2, 1e-2) @@ -150,7 +149,7 @@ class TestTensorCores(unittest.TestCase): assert len([uop for uop in tuple(program.src[2].src) if uop.op is Ops.WMMA]) > 0, "tensor core not triggered" assert len([x for x in program.src[0].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" - prg = CompiledRunner(program) + prg = CompiledRunner(program, Device.DEFAULT) # TODO: support this even if numpy doesn't if _to_np_dtype(real_bufs[0].dtype) is None: continue real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 58387b587b..eb6cfce057 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,8 +1,8 @@ from typing import cast from dataclasses import replace -import itertools, weakref -from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES -from tinygrad.helpers import TracingKey, Context, Target, panic +import itertools +from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC +from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, Target, panic from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, pyrender from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer, Estimates @@ -171,18 +171,18 @@ def do_to_program(ast:UOp, renderer:Renderer) -> UOp: elif ast.op is Ops.SINK: assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to to_program" full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None, beam=ast.arg.beam) - prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)), - arg=ProgramInfo.from_sink(full_sink, renderer.target.device)) + prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)), arg=ProgramInfo.from_sink(full_sink)) else: raise RuntimeError(f"can't call to_program on {ast.op}") - if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0], prg.src[1].arg)) + if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0])) prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render") if VIZ: graph_rewrite(prg, PatternMatcher([]), name="View Program") return prg -to_program_cache: weakref.WeakValueDictionary[tuple, UOp] = weakref.WeakValueDictionary() +to_program_cache: dict[tuple, UOp] = {} def to_program(ast:UOp, renderer:Renderer) -> UOp: if ast.op is Ops.PROGRAM and len(ast.src) >= 5 and ast.src[4].op is Ops.BINARY: - return ast if isinstance(ast.arg, ProgramInfo) else ast.replace(arg=ProgramInfo.from_sink(ast.src[0], ast.src[1].arg)) - key = (ast.key, type(renderer), renderer.target, NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value) + return ast if isinstance(ast.arg, ProgramInfo) else ast.replace(arg=ProgramInfo.from_sink(ast.src[0])) + config = (NOOPT, DEVECTORIZE, EMULATED_DTYPES, NOLOCALS, USE_TC, IMAGE, DISABLE_FAST_IDIV, TRANSCENDENTAL, ALLOW_TF32) + key = (ast.key, type(renderer), renderer.target, *[x.value for x in config]) if (prg:=to_program_cache.get(key)) is None: to_program_cache[key] = prg = do_to_program(ast, renderer) return prg diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index a1385042d0..e431aebe12 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -43,13 +43,13 @@ def _time_program(prg:UOp, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buff global_size, factor = get_test_global_size(info.global_size, max_global_size, var_vals) prg = prg.replace(arg=replace(info, global_size=tuple(global_size))) if len(prg.src) <= 4 or prg.src[4].op is not Ops.BINARY: prg = prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),)) - try: car = CompiledRunner(prg) + try: car = CompiledRunner(prg, prg.src[1].arg) except AssertionError: return [math.inf] * cnt tms = [] input_bufs = [rawbufs[i] for i in car.p.globals] for _ in range(cnt): if clear_l2: - if hasattr(dev:=Device[info.device], 'invalidate_caches'): dev.invalidate_caches() + if hasattr(dev:=Device[prg.src[1].arg], 'invalidate_caches'): dev.invalidate_caches() else: with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False) tms.append(unwrap(car(input_bufs, var_vals, wait=True, timeout=timeout))*factor) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 494da38bca..7187c2ac42 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -110,7 +110,7 @@ class GraphRunner(Runner): self.var_vals_replace:dict[int, list[tuple[int, int]]] = {} self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {} - self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {} + self.launch_dims_base:dict[int, tuple[tuple[int|float, ...], tuple[int, ...]]] = {} def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index a0d7b04305..0b45cd8d87 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,4 +1,4 @@ -from typing import cast, Callable, Iterator +from typing import cast, Iterator import time, random, itertools, math, contextlib, weakref from dataclasses import dataclass, replace, field from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey @@ -9,6 +9,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, from tinygrad.device import Device, Buffer, MultiBuffer from tinygrad.renderer import Estimates from tinygrad.codegen import to_program +from tinygrad.codegen.opt.postrange import bufs_from_ast # **************** Stat **************** @@ -55,45 +56,47 @@ class Runner: def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None: raise NotImplementedError("override this") -def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]: - test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs - MAX_WORKGROUP = 1024 - local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size] - local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice - def try_exec(local_size): - try: - return _prg(*[x._buf for x in test_rawbuffers],global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], - local_size=local_size, wait=True) - except Exception: return float('inf') - ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))]) - assert not math.isinf(ret[0]), "all optimize_local_size exec failed" - return ret[1] +local_size_cache: dict[bytes, tuple[int, ...]] = {} +def optimize_local_size(call:UOp, prg:UOp) -> UOp|None: + device = prg.src[1].arg + if prg.arg.local_size is not None or not Device[device].renderer.has_local or not all_int(prg.arg.global_size): return None + + if (local_size:=local_size_cache.get(prg.key)) is None: + bufs = [b._buf for b in (b.allocate() for b in bufs_from_ast(prg.src[0], device))] + rt = Device[device].runtime(prg.arg.function_name, prg.src[4].arg, *prg.arg.aux, runtimevars=prg.arg.runtimevars) + def try_exec(local_size): + try: return rt(*bufs, global_size=[g//l if g%l == 0 else g/l for g,l in zip(prg.arg.global_size, local_size)], local_size=local_size, wait=True) + except Exception: return float('inf') + + MAX_WORKGROUP = 1024 + local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in prg.arg.global_size] + local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice + best_time, best = min([(try_exec(ls), ls) for ls in random.sample(local_sizes, len(local_sizes))]) + assert not math.isinf(best_time), "all optimize_local_size exec failed" + local_size = local_size_cache[prg.key] = tuple(best) + + new_global = tuple(g//l if g%l == 0 else g/l for g,l in zip(prg.arg.global_size, local_size)) + return call.replace(src=(prg.replace(arg=replace(prg.arg, global_size=new_global, local_size=local_size)), *call.src[1:])) class CompiledRunner(Runner): - def __init__(self, prg:UOp, _prg=None): + def __init__(self, prg:UOp, device:str): info: ProgramInfo = prg.arg sink = prg.src[0] if DEBUG >= 3 and sink.arg.applied_opts: print(sink.arg.applied_opts) if DEBUG >= 4: print(prg.src[3].arg) if len(prg.src) <= 4 or prg.src[4].op is not Ops.BINARY: with cpu_profile(TracingKey(f"compile {info.name}", (info.function_name,)), "TINY"): - lib = Device[info.device].compiler.compile_cached(prg.src[3].arg) + lib = Device[device].compiler.compile_cached(prg.src[3].arg) prg = prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),)) self.prg:UOp = prg self.p:ProgramInfo = info - if DEBUG >= 7: Device[info.device].compiler.disassemble(prg.src[4].arg) - self._prg = Device[info.device].runtime(info.function_name, prg.src[4].arg, *info.aux, runtimevars=info.runtimevars) if _prg is None else _prg - super().__init__(info.name, info.device, sink.arg.estimates or Estimates()) - - def __reduce__(self): return self.__class__, (self.prg,) + if DEBUG >= 7: Device[device].compiler.disassemble(prg.src[4].arg) + self._prg = Device[device].runtime(info.function_name, prg.src[4].arg, *info.aux, runtimevars=info.runtimevars) + super().__init__(info.name, device, sink.arg.estimates or Estimates()) def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False, timeout:int|None=None) -> float|None: if var_vals is None: var_vals = {} global_size, local_size = self.p.launch_dims(var_vals) - if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size): - local_size = optimize_local_size(self._prg, global_size, rawbufs) - global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] - self.p = replace(self.p, global_size=tuple(global_size), local_size=tuple(local_size)) return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None, vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait, timeout=timeout) @@ -107,10 +110,10 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner: if cret:=method_cache.get(ckey): return cret bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True) if bret:=method_cache.get(bkey): - method_cache[ckey] = ret = CompiledRunner(bret.prg.replace(arg=replace(bret.p, device=device))) + method_cache[ckey] = ret = CompiledRunner(bret.prg, device) else: prg = to_program(ast, Device[device].renderer) - method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg.replace(arg=replace(prg.arg, device=device))) + method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg, device) return ret # **************** run linear **************** @@ -229,6 +232,10 @@ pm_compile = PatternMatcher([ call.replace(src=(to_program(ast, Device[call.device if isinstance(call.device, str) else call.device[0]].renderer), *call.src[1:]))), ]) +pm_optimize_local_size = PatternMatcher([ + (UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="prg"),), name="call", allow_any_len=True), optimize_local_size), +]) + pm_exec = PatternMatcher([ (UPat(Ops.CALL, src=(UPat(Ops.BUFFER_VIEW, name="ast"),), name="call", allow_any_len=True), exec_view), (UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy), @@ -241,7 +248,8 @@ pm_exec = PatternMatcher([ def compile_linear(linear:UOp, beam=0, validate=False) -> UOp: if validate: linear = graph_rewrite(linear, pm_validate, name="validate", walk=True) if (beam_val:=(beam or BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=beam_val, walk=True) - return graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True) + linear = graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True) + return graph_rewrite(linear, pm_optimize_local_size, name="optimize local size", walk=True) def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False): if not jit: linear = compile_linear(linear, validate=VALIDATE_WITH_CPU) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 50cb0a782f..4bee02eac1 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -172,7 +172,7 @@ class HCQGraph(MultiGraphRunner): # Encode main commands based on ji type. if prg is not None: - enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1))) + enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1))) # type: ignore[arg-type] elif j in self.rdma_deps: dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j] for sig, val in dest_deps: dest_queue.wait(sig, val) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index e51cbbe14c..0c318d6954 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -969,8 +969,7 @@ class KernelInfo: @dataclass(frozen=True) class ProgramInfo: name: str = "test" - device: str = "" - global_size: tuple[int, ...] = (1, 1, 1) + global_size: tuple[int|float, ...] = (1, 1, 1) local_size: tuple[int, ...]|None = None vars: tuple[UOp, ...] = () globals: tuple[int, ...] = () @@ -985,12 +984,12 @@ class ProgramInfo: def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'} def launch_dims(self, var_vals:dict[str, int]): - global_size = [sym_infer(sz, var_vals) for sz in self.global_size] + global_size = [sym_infer(sz, var_vals) for sz in self.global_size] # type: ignore[arg-type] local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None return global_size, local_size @staticmethod - def from_sink(sink:UOp, device:str, aux:tuple=()) -> ProgramInfo: + def from_sink(sink:UOp, aux:tuple=()) -> ProgramInfo: _vars: list[UOp] = [] _globals: list[int] = [] outs: list[int] = [] @@ -1008,7 +1007,7 @@ class ProgramInfo: special_size = local_size if u.arg[0] == 'l' else global_size if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify()) if u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': global_size[0] = u.arg[2] + 1 - return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", device, tuple(global_size), + return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size), tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.arg)), tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)