local optimize as rewrite (#15953)

* local optimize as rewrite

* better

* x

* slightly rename

* fix

* ugh

* remove

* x

* remove

* not weak
This commit is contained in:
nimlgen
2026-04-28 22:51:04 +03:00
committed by GitHub
parent b3f0f8d349
commit 77965a22e5
16 changed files with 82 additions and 66 deletions

View File

@@ -623,6 +623,8 @@ jobs:
run: test/external/process_replay/reset.py
- name: openpilot compile3 0.11.0 driving_vision
run: BENCHMARK_LOG=openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.11.0 driving_vision (from pickle)
run: BENCHMARK_LOG=openpilot_0_11_0_vision_run_pickle RUN_PICKLE=1 PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM taskset -c 4-7 python3 examples/openpilot/compile3.py
- name: IR3 openpilot compile3 0.11.0 driving_vision
run: BENCHMARK_LOG=ir3_openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM:IR3 FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.11.0 driving_policy
@@ -668,6 +670,8 @@ jobs:
run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision PYTHONPATH="." GMMU=0 DEV=USB+AMD:LLVM ASSERT_MIN_STEP_TIME=50 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot load_pickle 0.10.1 driving_vision
run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision_load_pickle PYTHONPATH="." GMMU=0 DEV=USB+AMD ASSERT_MIN_LOAD_TIME=15 python3 examples/openpilot/load_pickle.py
- name: openpilot run_pickle 0.10.1 driving_vision
run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision_run_pickle RUN_PICKLE=1 PYTHONPATH="." GMMU=0 DEV=USB+AMD ASSERT_MIN_STEP_TIME=50 python3 examples/openpilot/compile3.py
testreddriverbenchmark:
name: AM Benchmark

View File

@@ -133,14 +133,20 @@ def bench(run, inputs):
run(**inputs).numpy()
if __name__ == "__main__":
onnx_file = fetch(OPENPILOT_MODEL)
inputs, outputs = compile(onnx_file)
if getenv("RUN_PICKLE"):
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
inputs = {name: Tensor(Tensor.randn(*[int(s) for s in view.src[1].arg], dtype=dtype).numpy(), device=device)
for name, (view, _vars, dtype, device) in zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_input_info)}
test_vs_compile(pickle_loaded, inputs)
else:
onnx_file = fetch(OPENPILOT_MODEL)
inputs, outputs = compile(onnx_file)
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
test_vs_compile(pickle_loaded, inputs, outputs)
if getenv("SELFTEST"):
test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
test_vs_compile(pickle_loaded, inputs, outputs)
if getenv("SELFTEST"):
test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
if getenv("BENCHMARK_LOG", ""):
bench(pickle_loaded, inputs)

View File

@@ -89,11 +89,11 @@ if __name__ == "__main__":
# remove debug sections
src = src.split("\t.file")[0]
assert '.extern .shared' not in src
info = ProgramInfo(name="matmul_kernel", device=Device.DEFAULT,
info = ProgramInfo(name="matmul_kernel",
global_size=(M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1), local_size=(32*compiled.metadata.num_warps, 1, 1))
sink = UOp.sink(arg=KernelInfo(name="matmul_kernel"))
prg_uop = UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR), UOp(Ops.SOURCE, arg=src)), arg=info)
runner = CompiledRunner(prg_uop)
runner = CompiledRunner(prg_uop, Device.DEFAULT)
all_bufs = [x.ensure_allocated() for x in bufs]
prg_bufs = [all_bufs[i] for i in runner.p.globals]
tflops = []

View File

@@ -1,6 +1,5 @@
import numpy as np
import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
@@ -430,7 +429,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
def get_prg(opts):
ast = realized_ast if opts is None else replace_opts(realized_ast, list(opts))
return CompiledRunner((pu:=to_program(ast, renderer=Device[Device.DEFAULT].renderer)).replace(arg=replace(pu.arg, device=device)))
return CompiledRunner(to_program(ast, renderer=Device[Device.DEFAULT].renderer), device)
def check_opt(opts):
prg = get_prg(opts=opts)

View File

@@ -19,9 +19,9 @@ def _test_uop_result(inputs:list[Tensor], prg:UOp, local_size=None):
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [x.uop.base.buffer for x in inputs]
info = replace(prg.arg, device=Device.DEFAULT)
info = prg.arg
if local_size is not None: info = replace(info, local_size=tuple(local_size))
ei = CompiledRunner(prg.replace(arg=info))
ei = CompiledRunner(prg.replace(arg=info), Device.DEFAULT)
ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_memoryview(), _to_np_dtype(x.dtype)) for x in outbufs]

View File

@@ -13,11 +13,10 @@ from tinygrad.device import is_dtype_supported
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.renderer.ptx import PTXRenderer
from test.helpers import to_uops_list
from dataclasses import replace
def _uops_to_prg(uops_list):
prg = to_program(UOp.sink(*uops_list, arg=KernelInfo()), Device[Device.DEFAULT].renderer)
return CompiledRunner(prg.replace(arg=replace(prg.arg, device=Device.DEFAULT)))
return CompiledRunner(prg, Device.DEFAULT)
def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
if op is Ops.CONST: uops.append(UOp.const(dtype, arg))

View File

@@ -166,7 +166,8 @@ class TestHCQ(unittest.TestCase):
b = a + 1
si = b.schedule_linear().src[-1]
runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer))
runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer),
Device.DEFAULT)
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()

View File

@@ -90,7 +90,7 @@ renderer = Device.default.renderer
allocator = Device.default.allocator
ps = to_program(ast, renderer)
cr = CompiledRunner(ps.replace(arg=replace(ps.arg, device=Device.DEFAULT)))
cr = CompiledRunner(ps, Device.DEFAULT)
gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.PARAM]), key=lambda u: u.arg)
# print(len(gs))

View File

@@ -2,7 +2,7 @@ import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.schedule import schedule_cache
from tinygrad.engine.realize import method_cache
from tinygrad.codegen import to_program
from tinygrad.codegen import to_program, to_program_cache
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general
from test.test_tiny import TestTiny
@@ -72,6 +72,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
schedule_cache.clear()
method_cache.clear()
to_program_cache.clear()
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()
fold_divmod_general.cache_clear()

View File

@@ -1,6 +1,5 @@
import numpy as np
import unittest
from dataclasses import replace
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
@@ -51,7 +50,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
pu = to_program(replace_opts(realized_ast, opts), Device[Device.DEFAULT].renderer)
if use_tensor_cores == 1: assert len([uop for uop in pu.src[2].src if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
assert len([x for x in pu.src[0].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg = CompiledRunner(pu.replace(arg=replace(pu.arg, device=Device.DEFAULT)))
prg = CompiledRunner(pu, Device.DEFAULT)
prg.exec(bufs)
if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = (1e-1, 2e-2) if dtype_out == dtypes.bfloat16 else (1e-2, 1e-2)
@@ -150,7 +149,7 @@ class TestTensorCores(unittest.TestCase):
assert len([uop for uop in tuple(program.src[2].src) if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in program.src[0].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg = CompiledRunner(program)
prg = CompiledRunner(program, Device.DEFAULT)
# TODO: support this even if numpy doesn't
if _to_np_dtype(real_bufs[0].dtype) is None: continue
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled

View File

@@ -1,8 +1,8 @@
from typing import cast
from dataclasses import replace
import itertools, weakref
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES
from tinygrad.helpers import TracingKey, Context, Target, panic
import itertools
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, Target, panic
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, pyrender
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer, Estimates
@@ -171,18 +171,18 @@ def do_to_program(ast:UOp, renderer:Renderer) -> UOp:
elif ast.op is Ops.SINK:
assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to to_program"
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None, beam=ast.arg.beam)
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)),
arg=ProgramInfo.from_sink(full_sink, renderer.target.device))
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)), arg=ProgramInfo.from_sink(full_sink))
else: raise RuntimeError(f"can't call to_program on {ast.op}")
if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0], prg.src[1].arg))
if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0]))
prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
if VIZ: graph_rewrite(prg, PatternMatcher([]), name="View Program")
return prg
to_program_cache: weakref.WeakValueDictionary[tuple, UOp] = weakref.WeakValueDictionary()
to_program_cache: dict[tuple, UOp] = {}
def to_program(ast:UOp, renderer:Renderer) -> UOp:
if ast.op is Ops.PROGRAM and len(ast.src) >= 5 and ast.src[4].op is Ops.BINARY:
return ast if isinstance(ast.arg, ProgramInfo) else ast.replace(arg=ProgramInfo.from_sink(ast.src[0], ast.src[1].arg))
key = (ast.key, type(renderer), renderer.target, NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
return ast if isinstance(ast.arg, ProgramInfo) else ast.replace(arg=ProgramInfo.from_sink(ast.src[0]))
config = (NOOPT, DEVECTORIZE, EMULATED_DTYPES, NOLOCALS, USE_TC, IMAGE, DISABLE_FAST_IDIV, TRANSCENDENTAL, ALLOW_TF32)
key = (ast.key, type(renderer), renderer.target, *[x.value for x in config])
if (prg:=to_program_cache.get(key)) is None: to_program_cache[key] = prg = do_to_program(ast, renderer)
return prg

View File

@@ -43,13 +43,13 @@ def _time_program(prg:UOp, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buff
global_size, factor = get_test_global_size(info.global_size, max_global_size, var_vals)
prg = prg.replace(arg=replace(info, global_size=tuple(global_size)))
if len(prg.src) <= 4 or prg.src[4].op is not Ops.BINARY: prg = prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
try: car = CompiledRunner(prg)
try: car = CompiledRunner(prg, prg.src[1].arg)
except AssertionError: return [math.inf] * cnt
tms = []
input_bufs = [rawbufs[i] for i in car.p.globals]
for _ in range(cnt):
if clear_l2:
if hasattr(dev:=Device[info.device], 'invalidate_caches'): dev.invalidate_caches()
if hasattr(dev:=Device[prg.src[1].arg], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(unwrap(car(input_bufs, var_vals, wait=True, timeout=timeout))*factor)

View File

@@ -110,7 +110,7 @@ class GraphRunner(Runner):
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
self.launch_dims_base:dict[int, tuple[tuple[int|float, ...], tuple[int, ...]]] = {}
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

View File

@@ -1,4 +1,4 @@
from typing import cast, Callable, Iterator
from typing import cast, Iterator
import time, random, itertools, math, contextlib, weakref
from dataclasses import dataclass, replace, field
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
@@ -9,6 +9,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers,
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.renderer import Estimates
from tinygrad.codegen import to_program
from tinygrad.codegen.opt.postrange import bufs_from_ast
# **************** Stat ****************
@@ -55,45 +56,47 @@ class Runner:
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
raise NotImplementedError("override this")
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
MAX_WORKGROUP = 1024
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
def try_exec(local_size):
try:
return _prg(*[x._buf for x in test_rawbuffers],global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)],
local_size=local_size, wait=True)
except Exception: return float('inf')
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]
local_size_cache: dict[bytes, tuple[int, ...]] = {}
def optimize_local_size(call:UOp, prg:UOp) -> UOp|None:
device = prg.src[1].arg
if prg.arg.local_size is not None or not Device[device].renderer.has_local or not all_int(prg.arg.global_size): return None
if (local_size:=local_size_cache.get(prg.key)) is None:
bufs = [b._buf for b in (b.allocate() for b in bufs_from_ast(prg.src[0], device))]
rt = Device[device].runtime(prg.arg.function_name, prg.src[4].arg, *prg.arg.aux, runtimevars=prg.arg.runtimevars)
def try_exec(local_size):
try: return rt(*bufs, global_size=[g//l if g%l == 0 else g/l for g,l in zip(prg.arg.global_size, local_size)], local_size=local_size, wait=True)
except Exception: return float('inf')
MAX_WORKGROUP = 1024
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in prg.arg.global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
best_time, best = min([(try_exec(ls), ls) for ls in random.sample(local_sizes, len(local_sizes))])
assert not math.isinf(best_time), "all optimize_local_size exec failed"
local_size = local_size_cache[prg.key] = tuple(best)
new_global = tuple(g//l if g%l == 0 else g/l for g,l in zip(prg.arg.global_size, local_size))
return call.replace(src=(prg.replace(arg=replace(prg.arg, global_size=new_global, local_size=local_size)), *call.src[1:]))
class CompiledRunner(Runner):
def __init__(self, prg:UOp, _prg=None):
def __init__(self, prg:UOp, device:str):
info: ProgramInfo = prg.arg
sink = prg.src[0]
if DEBUG >= 3 and sink.arg.applied_opts: print(sink.arg.applied_opts)
if DEBUG >= 4: print(prg.src[3].arg)
if len(prg.src) <= 4 or prg.src[4].op is not Ops.BINARY:
with cpu_profile(TracingKey(f"compile {info.name}", (info.function_name,)), "TINY"):
lib = Device[info.device].compiler.compile_cached(prg.src[3].arg)
lib = Device[device].compiler.compile_cached(prg.src[3].arg)
prg = prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
self.prg:UOp = prg
self.p:ProgramInfo = info
if DEBUG >= 7: Device[info.device].compiler.disassemble(prg.src[4].arg)
self._prg = Device[info.device].runtime(info.function_name, prg.src[4].arg, *info.aux, runtimevars=info.runtimevars) if _prg is None else _prg
super().__init__(info.name, info.device, sink.arg.estimates or Estimates())
def __reduce__(self): return self.__class__, (self.prg,)
if DEBUG >= 7: Device[device].compiler.disassemble(prg.src[4].arg)
self._prg = Device[device].runtime(info.function_name, prg.src[4].arg, *info.aux, runtimevars=info.runtimevars)
super().__init__(info.name, device, sink.arg.estimates or Estimates())
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False, timeout:int|None=None) -> float|None:
if var_vals is None: var_vals = {}
global_size, local_size = self.p.launch_dims(var_vals)
if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size):
local_size = optimize_local_size(self._prg, global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
self.p = replace(self.p, global_size=tuple(global_size), local_size=tuple(local_size))
return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None,
vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait, timeout=timeout)
@@ -107,10 +110,10 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
if cret:=method_cache.get(ckey): return cret
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(bret.prg.replace(arg=replace(bret.p, device=device)))
method_cache[ckey] = ret = CompiledRunner(bret.prg, device)
else:
prg = to_program(ast, Device[device].renderer)
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg.replace(arg=replace(prg.arg, device=device)))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg, device)
return ret
# **************** run linear ****************
@@ -229,6 +232,10 @@ pm_compile = PatternMatcher([
call.replace(src=(to_program(ast, Device[call.device if isinstance(call.device, str) else call.device[0]].renderer), *call.src[1:]))),
])
pm_optimize_local_size = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="prg"),), name="call", allow_any_len=True), optimize_local_size),
])
pm_exec = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.BUFFER_VIEW, name="ast"),), name="call", allow_any_len=True), exec_view),
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy),
@@ -241,7 +248,8 @@ pm_exec = PatternMatcher([
def compile_linear(linear:UOp, beam=0, validate=False) -> UOp:
if validate: linear = graph_rewrite(linear, pm_validate, name="validate", walk=True)
if (beam_val:=(beam or BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=beam_val, walk=True)
return graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True)
linear = graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True)
return graph_rewrite(linear, pm_optimize_local_size, name="optimize local size", walk=True)
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False):
if not jit: linear = compile_linear(linear, validate=VALIDATE_WITH_CPU)

View File

@@ -172,7 +172,7 @@ class HCQGraph(MultiGraphRunner):
# Encode main commands based on ji type.
if prg is not None:
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1)))
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1))) # type: ignore[arg-type]
elif j in self.rdma_deps:
dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j]
for sig, val in dest_deps: dest_queue.wait(sig, val)

View File

@@ -969,8 +969,7 @@ class KernelInfo:
@dataclass(frozen=True)
class ProgramInfo:
name: str = "test"
device: str = ""
global_size: tuple[int, ...] = (1, 1, 1)
global_size: tuple[int|float, ...] = (1, 1, 1)
local_size: tuple[int, ...]|None = None
vars: tuple[UOp, ...] = ()
globals: tuple[int, ...] = ()
@@ -985,12 +984,12 @@ class ProgramInfo:
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
def launch_dims(self, var_vals:dict[str, int]):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size]
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] # type: ignore[arg-type]
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
return global_size, local_size
@staticmethod
def from_sink(sink:UOp, device:str, aux:tuple=()) -> ProgramInfo:
def from_sink(sink:UOp, aux:tuple=()) -> ProgramInfo:
_vars: list[UOp] = []
_globals: list[int] = []
outs: list[int] = []
@@ -1008,7 +1007,7 @@ class ProgramInfo:
special_size = local_size if u.arg[0] == 'l' else global_size
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
if u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': global_size[0] = u.arg[2] + 1
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", device, tuple(global_size),
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.arg)),
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)