mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
var_vals uses str for var (#12011)
* var_vals is str,int * remove imports * remove print * fix test * change var_vals in hcq * update test_hcq * fix multitensor _device_num var * fix syminfer test * shorten line * p.vars stays list[Variable] * shorten line * vars is back to tuple[Variable, ...] * change var_vals in extra * change var_vals from shapetracker * var_vals is str:int * fix signature
This commit is contained in:
@@ -4,12 +4,11 @@ import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import init_c_var, dedup
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
||||
from tinygrad.uop.ops import Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
||||
|
||||
class CUDAGraph(MultiGraphRunner):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
|
||||
# Check all jit items are compatible.
|
||||
@@ -28,7 +27,7 @@ class CUDAGraph(MultiGraphRunner):
|
||||
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
|
||||
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
|
||||
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
|
||||
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
||||
|
||||
@@ -48,7 +47,7 @@ class CUDAGraph(MultiGraphRunner):
|
||||
|
||||
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
||||
# Update rawbuffers in the c_args struct.
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
|
||||
class HCQGraph(MultiGraphRunner):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
|
||||
|
||||
@@ -69,7 +69,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
|
||||
|
||||
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
|
||||
self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
|
||||
self.fixedvars: dict[HCQCompiled, dict[str, int]] = {}
|
||||
|
||||
for j,ji in enumerate(jit_cache):
|
||||
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
|
||||
@@ -183,7 +183,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
|
||||
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
||||
# Wait and restore signals
|
||||
self.kickoff_value += 1
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
@@ -195,12 +195,13 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
|
||||
|
||||
hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
|
||||
**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
|
||||
**{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
|
||||
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
|
||||
**{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
|
||||
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
|
||||
|
||||
# Update rawbuffers
|
||||
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr
|
||||
|
||||
for dev in self.devices:
|
||||
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
|
||||
|
||||
@@ -5,7 +5,6 @@ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
|
||||
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.uop.ops import Variable
|
||||
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
|
||||
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
|
||||
|
||||
@@ -17,7 +16,7 @@ class MTLResourceUsage:
|
||||
MTLResourceUsageWrite = 0b10
|
||||
|
||||
class MetalGraph(GraphRunner):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
|
||||
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
||||
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
|
||||
|
||||
@@ -48,7 +47,8 @@ class MetalGraph(GraphRunner):
|
||||
if b is not None and b not in input_rawbuffers:
|
||||
msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
|
||||
all_resources.append(b._buf.buf)
|
||||
for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)
|
||||
for i,v in enumerate(prg.p.vars):
|
||||
msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
|
||||
|
||||
global_size, local_size = prg.p.launch_dims(var_vals)
|
||||
msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
|
||||
@@ -61,7 +61,7 @@ class MetalGraph(GraphRunner):
|
||||
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
|
||||
self.range = to_struct(0, len(jit_cache))
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
||||
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
# NOTE: old command buffer may not be inflight anymore
|
||||
if self.command_buffer is not None and PROFILE: self.collect_timestamps()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import time, itertools
|
||||
from tinygrad.uop.ops import Variable
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
|
||||
from tinygrad.device import Device, Compiled, Buffer
|
||||
@@ -18,7 +17,7 @@ def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_
|
||||
def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
|
||||
|
||||
class RemoteGraph(MultiGraphRunner):
|
||||
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
|
||||
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
|
||||
super().__init__(jit_cache, rawbufs, var_vals)
|
||||
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
|
||||
c2d = {device.conn: device for device in devices}
|
||||
@@ -93,7 +92,7 @@ class RemoteGraph(MultiGraphRunner):
|
||||
for req in self.template:
|
||||
match req:
|
||||
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
|
||||
def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
|
||||
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
|
||||
if wait: st = time.perf_counter()
|
||||
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
|
||||
for req in self.template:
|
||||
|
||||
@@ -100,7 +100,7 @@ class GraphComputeItem:
|
||||
datahash: str
|
||||
bufs: tuple[int, ...]
|
||||
vars: tuple[Variable, ...]
|
||||
fixedvars: dict[Variable, int]
|
||||
fixedvars: dict[str, int]
|
||||
ins: tuple[int, ...]
|
||||
outs: tuple[int, ...]
|
||||
global_size: tuple[sint, ...]|None
|
||||
@@ -111,7 +111,7 @@ class GraphAlloc(RemoteRequest):
|
||||
graph_num: int
|
||||
jit_cache: tuple[GraphComputeItem|Transfer, ...]
|
||||
bufs: tuple[tuple[SessionKey, int], ...]
|
||||
var_vals: dict[Variable, int]
|
||||
var_vals: dict[str, int]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GraphFree(RemoteRequest):
|
||||
@@ -121,7 +121,7 @@ class GraphFree(RemoteRequest):
|
||||
class GraphExec(RemoteRequest):
|
||||
graph_num: int
|
||||
bufs: tuple[tuple[SessionKey, int], ...]
|
||||
var_vals: dict[Variable, int]
|
||||
var_vals: dict[str, int]
|
||||
wait: bool
|
||||
|
||||
# for safe deserialization
|
||||
|
||||
@@ -6,7 +6,7 @@ except ImportError: fcntl = None #type:ignore[assignment]
|
||||
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
|
||||
from tinygrad.uop.ops import sym_infer, sint, Variable, UOp
|
||||
from tinygrad.uop.ops import sym_infer, sint, UOp
|
||||
from tinygrad.runtime.autogen import libc
|
||||
|
||||
class MMIOInterface:
|
||||
@@ -192,7 +192,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
|
||||
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
|
||||
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
|
||||
|
||||
def _apply_var_vals(self, var_vals:dict[Variable, int]):
|
||||
def _apply_var_vals(self, var_vals:dict[str, int]):
|
||||
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
|
||||
|
||||
for off, sym_idx in self.q_sints:
|
||||
@@ -205,7 +205,7 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
|
||||
|
||||
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
|
||||
|
||||
def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
|
||||
def submit(self, dev:HCQDeviceType, var_vals:dict[str, int]|None=None):
|
||||
"""
|
||||
Submits the command queue to a specific device for execution.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user