var_vals uses str for var (#12011)

* var_vals is str,int

* remove imports

* remove print

* fix test

* change var_vals in hcq

* update test_hcq

* fix multitensor _device_num var

* fix syminfer test

* shorten line

* p.vars stays list[Variable]

* shorten line

* vars is back to tuple[Variable, ...]

* change var_vals in extra

* change var_vals from shapetracker

* var_vals is str:int

* fix signature
This commit is contained in:
Sieds Lykles
2025-09-06 04:16:12 +02:00
committed by GitHub
parent 8658a97197
commit c6c16b2946
24 changed files with 90 additions and 91 deletions

View File

@@ -4,12 +4,11 @@ import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.uop.ops import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
class CUDAGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
# Check all jit items are compatible.
@@ -28,7 +27,7 @@ class CUDAGraph(MultiGraphRunner):
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
@@ -48,7 +47,7 @@ class CUDAGraph(MultiGraphRunner):
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)

View File

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
@@ -69,7 +69,7 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
self.fixedvars: dict[HCQCompiled, dict[str, int]] = {}
for j,ji in enumerate(jit_cache):
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
@@ -183,7 +183,7 @@ class HCQGraph(MultiGraphRunner):
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
@@ -195,12 +195,13 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
**{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
for (j,i),input_idx in self.input_replace.items():
hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))

View File

@@ -5,7 +5,6 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.uop.ops import Variable
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
@@ -17,7 +16,7 @@ class MTLResourceUsage:
MTLResourceUsageWrite = 0b10
class MetalGraph(GraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
@@ -48,7 +47,8 @@ class MetalGraph(GraphRunner):
if b is not None and b not in input_rawbuffers:
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.p.vars):
msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
@@ -61,7 +61,7 @@ class MetalGraph(GraphRunner):
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = to_struct(0, len(jit_cache))
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
# NOTE: old command buffer may not be inflight anymore
if self.command_buffer is not None and PROFILE: self.collect_timestamps()

View File

@@ -1,5 +1,4 @@
import time, itertools
from tinygrad.uop.ops import Variable
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
@@ -18,7 +17,7 @@ def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
class RemoteGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
@@ -93,7 +92,7 @@ class RemoteGraph(MultiGraphRunner):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:

View File

@@ -100,7 +100,7 @@ class GraphComputeItem:
datahash: str
bufs: tuple[int, ...]
vars: tuple[Variable, ...]
fixedvars: dict[Variable, int]
fixedvars: dict[str, int]
ins: tuple[int, ...]
outs: tuple[int, ...]
global_size: tuple[sint, ...]|None
@@ -111,7 +111,7 @@ class GraphAlloc(RemoteRequest):
graph_num: int
jit_cache: tuple[GraphComputeItem|Transfer, ...]
bufs: tuple[tuple[SessionKey, int], ...]
var_vals: dict[Variable, int]
var_vals: dict[str, int]
@dataclass(frozen=True)
class GraphFree(RemoteRequest):
@@ -121,7 +121,7 @@ class GraphFree(RemoteRequest):
class GraphExec(RemoteRequest):
graph_num: int
bufs: tuple[tuple[SessionKey, int], ...]
var_vals: dict[Variable, int]
var_vals: dict[str, int]
wait: bool
# for safe deserialization

View File

@@ -6,7 +6,7 @@ except ImportError: fcntl = None #type:ignore[assignment]
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.uop.ops import sym_infer, sint, Variable, UOp
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
class MMIOInterface:
@@ -192,7 +192,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
def _apply_var_vals(self, var_vals:dict[Variable, int]):
def _apply_var_vals(self, var_vals:dict[str, int]):
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
for off, sym_idx in self.q_sints:
@@ -205,7 +205,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
def submit(self, dev:HCQDeviceType, var_vals:dict[str, int]|None=None):
"""
Submits the command queue to a specific device for execution.