From bf7d1fcd2ce3626e63a95869e654045b1f6d2dd6 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:30:06 +0300 Subject: [PATCH] tiny import fixes in hcq graph (#8184) --- tinygrad/runtime/graph/hcq.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 7d14b0ade8..44d64f811b 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -3,22 +3,22 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set from tinygrad.helpers import round_up, PROFILE, memsize_to_str from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator from tinygrad.device import Buffer, BufferSpec, Compiled, Device -from tinygrad import Variable, dtypes -from tinygrad.ops import Variable as VariableT +from tinygrad.dtype import dtypes +from tinygrad.ops import UOp, Variable from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner from tinygrad.engine.jit import MultiGraphRunner class HCQGraph(MultiGraphRunner): - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int]): + def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) # Replace input buffers with variables. self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache] - self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {} + self.input_replace_to_var: Dict[Tuple[int, int], Variable] = {} for (j,i), input_idx in self.input_replace.items(): - x = self.input_replace_to_var.setdefault((j,i), Variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64)) + x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64)) self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable # Allocate kernel args. @@ -49,7 +49,7 @@ class HCQGraph(MultiGraphRunner): self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}} self.kickoff_value: int = 0 - self.kickoff_var = Variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32) + self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32) self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else [] self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = [] @@ -118,8 +118,8 @@ class HCQGraph(MultiGraphRunner): self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices} # Create variable timeline signals for each device. - timeline_sigaddrs = {dev: Variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices} - self.virt_timeline_vals = {dev: Variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices} + timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices} + self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices} self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices} for dev in self.devices: @@ -159,7 +159,7 @@ class HCQGraph(MultiGraphRunner): self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices} self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals] - def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int], wait=False) -> Optional[float]: + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: # Wait and restore signals self.kickoff_value += 1 for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])