mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
local optimize as rewrite (#15953)
* local optimize as rewrite * better * x * slighly rename * fix * ugh * remove * x * remove * not weak
This commit is contained in:
4
.github/workflows/benchmark.yml
vendored
4
.github/workflows/benchmark.yml
vendored
@@ -623,6 +623,8 @@ jobs:
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: openpilot compile3 0.11.0 driving_vision
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot compile3 0.11.0 driving_vision (from pickle)
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_vision_run_pickle RUN_PICKLE=1 PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM taskset -c 4-7 python3 examples/openpilot/compile3.py
|
||||
- name: IR3 openpilot compile3 0.11.0 driving_vision
|
||||
run: BENCHMARK_LOG=ir3_openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM:IR3 FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot compile3 0.11.0 driving_policy
|
||||
@@ -668,6 +670,8 @@ jobs:
|
||||
run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision PYTHONPATH="." GMMU=0 DEV=USB+AMD:LLVM ASSERT_MIN_STEP_TIME=50 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot load_pickle 0.10.1 driving_vision
|
||||
run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision_load_pickle PYTHONPATH="." GMMU=0 DEV=USB+AMD ASSERT_MIN_LOAD_TIME=15 python3 examples/openpilot/load_pickle.py
|
||||
- name: openpilot run_pickle 0.10.1 driving_vision
|
||||
run: BENCHMARK_LOG=usbgpu_openpilot_0_10_1_vision_run_pickle RUN_PICKLE=1 PYTHONPATH="." GMMU=0 DEV=USB+AMD ASSERT_MIN_STEP_TIME=50 python3 examples/openpilot/compile3.py
|
||||
|
||||
testreddriverbenchmark:
|
||||
name: AM Benchmark
|
||||
|
||||
@@ -133,14 +133,20 @@ def bench(run, inputs):
|
||||
run(**inputs).numpy()
|
||||
|
||||
if __name__ == "__main__":
|
||||
onnx_file = fetch(OPENPILOT_MODEL)
|
||||
inputs, outputs = compile(onnx_file)
|
||||
if getenv("RUN_PICKLE"):
|
||||
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
|
||||
inputs = {name: Tensor(Tensor.randn(*[int(s) for s in view.src[1].arg], dtype=dtype).numpy(), device=device)
|
||||
for name, (view, _vars, dtype, device) in zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_input_info)}
|
||||
test_vs_compile(pickle_loaded, inputs)
|
||||
else:
|
||||
onnx_file = fetch(OPENPILOT_MODEL)
|
||||
inputs, outputs = compile(onnx_file)
|
||||
|
||||
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
|
||||
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
|
||||
|
||||
test_vs_compile(pickle_loaded, inputs, outputs)
|
||||
if getenv("SELFTEST"):
|
||||
test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
|
||||
test_vs_compile(pickle_loaded, inputs, outputs)
|
||||
if getenv("SELFTEST"):
|
||||
test_vs_onnx(inputs, outputs, onnx_file, 1e-4)
|
||||
|
||||
if getenv("BENCHMARK_LOG", ""):
|
||||
bench(pickle_loaded, inputs)
|
||||
|
||||
@@ -89,11 +89,11 @@ if __name__ == "__main__":
|
||||
# remove debug sections
|
||||
src = src.split("\t.file")[0]
|
||||
assert '.extern .shared' not in src
|
||||
info = ProgramInfo(name="matmul_kernel", device=Device.DEFAULT,
|
||||
info = ProgramInfo(name="matmul_kernel",
|
||||
global_size=(M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1), local_size=(32*compiled.metadata.num_warps, 1, 1))
|
||||
sink = UOp.sink(arg=KernelInfo(name="matmul_kernel"))
|
||||
prg_uop = UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR), UOp(Ops.SOURCE, arg=src)), arg=info)
|
||||
runner = CompiledRunner(prg_uop)
|
||||
runner = CompiledRunner(prg_uop, Device.DEFAULT)
|
||||
all_bufs = [x.ensure_allocated() for x in bufs]
|
||||
prg_bufs = [all_bufs[i] for i in runner.p.globals]
|
||||
tflops = []
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from dataclasses import replace
|
||||
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
|
||||
@@ -430,7 +429,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
|
||||
|
||||
def get_prg(opts):
|
||||
ast = realized_ast if opts is None else replace_opts(realized_ast, list(opts))
|
||||
return CompiledRunner((pu:=to_program(ast, renderer=Device[Device.DEFAULT].renderer)).replace(arg=replace(pu.arg, device=device)))
|
||||
return CompiledRunner(to_program(ast, renderer=Device[Device.DEFAULT].renderer), device)
|
||||
|
||||
def check_opt(opts):
|
||||
prg = get_prg(opts=opts)
|
||||
|
||||
@@ -19,9 +19,9 @@ def _test_uop_result(inputs:list[Tensor], prg:UOp, local_size=None):
|
||||
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
|
||||
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
|
||||
inbufs = [x.uop.base.buffer for x in inputs]
|
||||
info = replace(prg.arg, device=Device.DEFAULT)
|
||||
info = prg.arg
|
||||
if local_size is not None: info = replace(info, local_size=tuple(local_size))
|
||||
ei = CompiledRunner(prg.replace(arg=info))
|
||||
ei = CompiledRunner(prg.replace(arg=info), Device.DEFAULT)
|
||||
ei.exec(outbufs+inbufs)
|
||||
return [np.frombuffer(x.as_memoryview(), _to_np_dtype(x.dtype)) for x in outbufs]
|
||||
|
||||
|
||||
@@ -13,11 +13,10 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from test.helpers import to_uops_list
|
||||
from dataclasses import replace
|
||||
|
||||
def _uops_to_prg(uops_list):
|
||||
prg = to_program(UOp.sink(*uops_list, arg=KernelInfo()), Device[Device.DEFAULT].renderer)
|
||||
return CompiledRunner(prg.replace(arg=replace(prg.arg, device=Device.DEFAULT)))
|
||||
return CompiledRunner(prg, Device.DEFAULT)
|
||||
|
||||
def uop(uops:list[UOp], op:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
if op is Ops.CONST: uops.append(UOp.const(dtype, arg))
|
||||
|
||||
@@ -166,7 +166,8 @@ class TestHCQ(unittest.TestCase):
|
||||
b = a + 1
|
||||
si = b.schedule_linear().src[-1]
|
||||
|
||||
runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer))
|
||||
runner = CompiledRunner(to_program(replace_opts(si.src[0], [Opt(op=OptOps.LOCAL, axis=0, arg=3) for _ in range(3)]), TestHCQ.d0.renderer),
|
||||
Device.DEFAULT)
|
||||
|
||||
zb = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
zt = Buffer(Device.DEFAULT, 3 * 3 * 3, dtypes.int, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
|
||||
2
test/external/external_benchmark_op_conv.py
vendored
2
test/external/external_benchmark_op_conv.py
vendored
@@ -90,7 +90,7 @@ renderer = Device.default.renderer
|
||||
allocator = Device.default.allocator
|
||||
|
||||
ps = to_program(ast, renderer)
|
||||
cr = CompiledRunner(ps.replace(arg=replace(ps.arg, device=Device.DEFAULT)))
|
||||
cr = CompiledRunner(ps, Device.DEFAULT)
|
||||
|
||||
gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.PARAM]), key=lambda u: u.arg)
|
||||
# print(len(gs))
|
||||
|
||||
3
test/external/external_uop_gc.py
vendored
3
test/external/external_uop_gc.py
vendored
@@ -2,7 +2,7 @@ import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.schedule import schedule_cache
|
||||
from tinygrad.engine.realize import method_cache
|
||||
from tinygrad.codegen import to_program
|
||||
from tinygrad.codegen import to_program, to_program_cache
|
||||
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
from test.test_tiny import TestTiny
|
||||
@@ -72,6 +72,7 @@ if __name__ == "__main__":
|
||||
# these caches will keep uops alive
|
||||
schedule_cache.clear()
|
||||
method_cache.clear()
|
||||
to_program_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
_apply_reshape.cache_clear()
|
||||
fold_divmod_general.cache_clear()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
from dataclasses import replace
|
||||
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
@@ -51,7 +50,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
|
||||
pu = to_program(replace_opts(realized_ast, opts), Device[Device.DEFAULT].renderer)
|
||||
if use_tensor_cores == 1: assert len([uop for uop in pu.src[2].src if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
|
||||
assert len([x for x in pu.src[0].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
prg = CompiledRunner(pu.replace(arg=replace(pu.arg, device=Device.DEFAULT)))
|
||||
prg = CompiledRunner(pu, Device.DEFAULT)
|
||||
prg.exec(bufs)
|
||||
if dtype_in == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
|
||||
elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = (1e-1, 2e-2) if dtype_out == dtypes.bfloat16 else (1e-2, 1e-2)
|
||||
@@ -150,7 +149,7 @@ class TestTensorCores(unittest.TestCase):
|
||||
assert len([uop for uop in tuple(program.src[2].src) if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
|
||||
assert len([x for x in program.src[0].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
|
||||
prg = CompiledRunner(program)
|
||||
prg = CompiledRunner(program, Device.DEFAULT)
|
||||
# TODO: support this even if numpy doesn't
|
||||
if _to_np_dtype(real_bufs[0].dtype) is None: continue
|
||||
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import cast
|
||||
from dataclasses import replace
|
||||
import itertools, weakref
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES
|
||||
from tinygrad.helpers import TracingKey, Context, Target, panic
|
||||
import itertools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, Target, panic
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
@@ -171,18 +171,18 @@ def do_to_program(ast:UOp, renderer:Renderer) -> UOp:
|
||||
elif ast.op is Ops.SINK:
|
||||
assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to to_program"
|
||||
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None, beam=ast.arg.beam)
|
||||
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)),
|
||||
arg=ProgramInfo.from_sink(full_sink, renderer.target.device))
|
||||
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)), arg=ProgramInfo.from_sink(full_sink))
|
||||
else: raise RuntimeError(f"can't call to_program on {ast.op}")
|
||||
if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0], prg.src[1].arg))
|
||||
if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0]))
|
||||
prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
|
||||
if VIZ: graph_rewrite(prg, PatternMatcher([]), name="View Program")
|
||||
return prg
|
||||
|
||||
to_program_cache: weakref.WeakValueDictionary[tuple, UOp] = weakref.WeakValueDictionary()
|
||||
to_program_cache: dict[tuple, UOp] = {}
|
||||
def to_program(ast:UOp, renderer:Renderer) -> UOp:
|
||||
if ast.op is Ops.PROGRAM and len(ast.src) >= 5 and ast.src[4].op is Ops.BINARY:
|
||||
return ast if isinstance(ast.arg, ProgramInfo) else ast.replace(arg=ProgramInfo.from_sink(ast.src[0], ast.src[1].arg))
|
||||
key = (ast.key, type(renderer), renderer.target, NOOPT.value, DEVECTORIZE.value, EMULATED_DTYPES.value)
|
||||
return ast if isinstance(ast.arg, ProgramInfo) else ast.replace(arg=ProgramInfo.from_sink(ast.src[0]))
|
||||
config = (NOOPT, DEVECTORIZE, EMULATED_DTYPES, NOLOCALS, USE_TC, IMAGE, DISABLE_FAST_IDIV, TRANSCENDENTAL, ALLOW_TF32)
|
||||
key = (ast.key, type(renderer), renderer.target, *[x.value for x in config])
|
||||
if (prg:=to_program_cache.get(key)) is None: to_program_cache[key] = prg = do_to_program(ast, renderer)
|
||||
return prg
|
||||
|
||||
@@ -43,13 +43,13 @@ def _time_program(prg:UOp, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buff
|
||||
global_size, factor = get_test_global_size(info.global_size, max_global_size, var_vals)
|
||||
prg = prg.replace(arg=replace(info, global_size=tuple(global_size)))
|
||||
if len(prg.src) <= 4 or prg.src[4].op is not Ops.BINARY: prg = prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
|
||||
try: car = CompiledRunner(prg)
|
||||
try: car = CompiledRunner(prg, prg.src[1].arg)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
input_bufs = [rawbufs[i] for i in car.p.globals]
|
||||
for _ in range(cnt):
|
||||
if clear_l2:
|
||||
if hasattr(dev:=Device[info.device], 'invalidate_caches'): dev.invalidate_caches()
|
||||
if hasattr(dev:=Device[prg.src[1].arg], 'invalidate_caches'): dev.invalidate_caches()
|
||||
else:
|
||||
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
|
||||
tms.append(unwrap(car(input_bufs, var_vals, wait=True, timeout=timeout))*factor)
|
||||
|
||||
@@ -110,7 +110,7 @@ class GraphRunner(Runner):
|
||||
|
||||
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
|
||||
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
|
||||
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
|
||||
self.launch_dims_base:dict[int, tuple[tuple[int|float, ...], tuple[int, ...]]] = {}
|
||||
|
||||
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import cast, Callable, Iterator
|
||||
from typing import cast, Iterator
|
||||
import time, random, itertools, math, contextlib, weakref
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
@@ -9,6 +9,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers,
|
||||
from tinygrad.device import Device, Buffer, MultiBuffer
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.codegen import to_program
|
||||
from tinygrad.codegen.opt.postrange import bufs_from_ast
|
||||
|
||||
# **************** Stat ****************
|
||||
|
||||
@@ -55,45 +56,47 @@ class Runner:
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False) -> float|None:
|
||||
raise NotImplementedError("override this")
|
||||
|
||||
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
|
||||
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
|
||||
MAX_WORKGROUP = 1024
|
||||
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
|
||||
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
|
||||
def try_exec(local_size):
|
||||
try:
|
||||
return _prg(*[x._buf for x in test_rawbuffers],global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)],
|
||||
local_size=local_size, wait=True)
|
||||
except Exception: return float('inf')
|
||||
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
|
||||
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
|
||||
return ret[1]
|
||||
local_size_cache: dict[bytes, tuple[int, ...]] = {}
|
||||
def optimize_local_size(call:UOp, prg:UOp) -> UOp|None:
|
||||
device = prg.src[1].arg
|
||||
if prg.arg.local_size is not None or not Device[device].renderer.has_local or not all_int(prg.arg.global_size): return None
|
||||
|
||||
if (local_size:=local_size_cache.get(prg.key)) is None:
|
||||
bufs = [b._buf for b in (b.allocate() for b in bufs_from_ast(prg.src[0], device))]
|
||||
rt = Device[device].runtime(prg.arg.function_name, prg.src[4].arg, *prg.arg.aux, runtimevars=prg.arg.runtimevars)
|
||||
def try_exec(local_size):
|
||||
try: return rt(*bufs, global_size=[g//l if g%l == 0 else g/l for g,l in zip(prg.arg.global_size, local_size)], local_size=local_size, wait=True)
|
||||
except Exception: return float('inf')
|
||||
|
||||
MAX_WORKGROUP = 1024
|
||||
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in prg.arg.global_size]
|
||||
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
|
||||
best_time, best = min([(try_exec(ls), ls) for ls in random.sample(local_sizes, len(local_sizes))])
|
||||
assert not math.isinf(best_time), "all optimize_local_size exec failed"
|
||||
local_size = local_size_cache[prg.key] = tuple(best)
|
||||
|
||||
new_global = tuple(g//l if g%l == 0 else g/l for g,l in zip(prg.arg.global_size, local_size))
|
||||
return call.replace(src=(prg.replace(arg=replace(prg.arg, global_size=new_global, local_size=local_size)), *call.src[1:]))
|
||||
|
||||
class CompiledRunner(Runner):
|
||||
def __init__(self, prg:UOp, _prg=None):
|
||||
def __init__(self, prg:UOp, device:str):
|
||||
info: ProgramInfo = prg.arg
|
||||
sink = prg.src[0]
|
||||
if DEBUG >= 3 and sink.arg.applied_opts: print(sink.arg.applied_opts)
|
||||
if DEBUG >= 4: print(prg.src[3].arg)
|
||||
if len(prg.src) <= 4 or prg.src[4].op is not Ops.BINARY:
|
||||
with cpu_profile(TracingKey(f"compile {info.name}", (info.function_name,)), "TINY"):
|
||||
lib = Device[info.device].compiler.compile_cached(prg.src[3].arg)
|
||||
lib = Device[device].compiler.compile_cached(prg.src[3].arg)
|
||||
prg = prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),))
|
||||
self.prg:UOp = prg
|
||||
self.p:ProgramInfo = info
|
||||
if DEBUG >= 7: Device[info.device].compiler.disassemble(prg.src[4].arg)
|
||||
self._prg = Device[info.device].runtime(info.function_name, prg.src[4].arg, *info.aux, runtimevars=info.runtimevars) if _prg is None else _prg
|
||||
super().__init__(info.name, info.device, sink.arg.estimates or Estimates())
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.prg,)
|
||||
if DEBUG >= 7: Device[device].compiler.disassemble(prg.src[4].arg)
|
||||
self._prg = Device[device].runtime(info.function_name, prg.src[4].arg, *info.aux, runtimevars=info.runtimevars)
|
||||
super().__init__(info.name, device, sink.arg.estimates or Estimates())
|
||||
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False, timeout:int|None=None) -> float|None:
|
||||
if var_vals is None: var_vals = {}
|
||||
global_size, local_size = self.p.launch_dims(var_vals)
|
||||
if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size):
|
||||
local_size = optimize_local_size(self._prg, global_size, rawbufs)
|
||||
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
||||
self.p = replace(self.p, global_size=tuple(global_size), local_size=tuple(local_size))
|
||||
return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None,
|
||||
vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait, timeout=timeout)
|
||||
|
||||
@@ -107,10 +110,10 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
||||
if cret:=method_cache.get(ckey): return cret
|
||||
bkey = (device.split(":")[0], type(Device[device].compiler), ast.key, context, True)
|
||||
if bret:=method_cache.get(bkey):
|
||||
method_cache[ckey] = ret = CompiledRunner(bret.prg.replace(arg=replace(bret.p, device=device)))
|
||||
method_cache[ckey] = ret = CompiledRunner(bret.prg, device)
|
||||
else:
|
||||
prg = to_program(ast, Device[device].renderer)
|
||||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg.replace(arg=replace(prg.arg, device=device)))
|
||||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(prg, device)
|
||||
return ret
|
||||
|
||||
# **************** run linear ****************
|
||||
@@ -229,6 +232,10 @@ pm_compile = PatternMatcher([
|
||||
call.replace(src=(to_program(ast, Device[call.device if isinstance(call.device, str) else call.device[0]].renderer), *call.src[1:]))),
|
||||
])
|
||||
|
||||
pm_optimize_local_size = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="prg"),), name="call", allow_any_len=True), optimize_local_size),
|
||||
])
|
||||
|
||||
pm_exec = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.BUFFER_VIEW, name="ast"),), name="call", allow_any_len=True), exec_view),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy),
|
||||
@@ -241,7 +248,8 @@ pm_exec = PatternMatcher([
|
||||
def compile_linear(linear:UOp, beam=0, validate=False) -> UOp:
|
||||
if validate: linear = graph_rewrite(linear, pm_validate, name="validate", walk=True)
|
||||
if (beam_val:=(beam or BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=beam_val, walk=True)
|
||||
return graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True)
|
||||
linear = graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True)
|
||||
return graph_rewrite(linear, pm_optimize_local_size, name="optimize local size", walk=True)
|
||||
|
||||
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False):
|
||||
if not jit: linear = compile_linear(linear, validate=VALIDATE_WITH_CPU)
|
||||
|
||||
@@ -172,7 +172,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
# Encode main commands based on ji type.
|
||||
if prg is not None:
|
||||
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1)))
|
||||
enqueue_queue.exec(prg._prg, self.ji_args[j], tuple(prg.p.global_size or (1,1,1)), tuple(prg.p.local_size or (1,1,1))) # type: ignore[arg-type]
|
||||
elif j in self.rdma_deps:
|
||||
dest_queue, dest_deps, dest_out_signal, dest_out_val = self.rdma_deps[j]
|
||||
for sig, val in dest_deps: dest_queue.wait(sig, val)
|
||||
|
||||
@@ -969,8 +969,7 @@ class KernelInfo:
|
||||
@dataclass(frozen=True)
|
||||
class ProgramInfo:
|
||||
name: str = "test"
|
||||
device: str = ""
|
||||
global_size: tuple[int, ...] = (1, 1, 1)
|
||||
global_size: tuple[int|float, ...] = (1, 1, 1)
|
||||
local_size: tuple[int, ...]|None = None
|
||||
vars: tuple[UOp, ...] = ()
|
||||
globals: tuple[int, ...] = ()
|
||||
@@ -985,12 +984,12 @@ class ProgramInfo:
|
||||
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
|
||||
|
||||
def launch_dims(self, var_vals:dict[str, int]):
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size]
|
||||
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] # type: ignore[arg-type]
|
||||
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
|
||||
return global_size, local_size
|
||||
|
||||
@staticmethod
|
||||
def from_sink(sink:UOp, device:str, aux:tuple=()) -> ProgramInfo:
|
||||
def from_sink(sink:UOp, aux:tuple=()) -> ProgramInfo:
|
||||
_vars: list[UOp] = []
|
||||
_globals: list[int] = []
|
||||
outs: list[int] = []
|
||||
@@ -1008,7 +1007,7 @@ class ProgramInfo:
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
|
||||
if u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': global_size[0] = u.arg[2] + 1
|
||||
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", device, tuple(global_size),
|
||||
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
|
||||
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.arg)),
|
||||
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user