add hcq fixedvars support [pr] (#10356)

* add hcq fixedvars support [pr]

* different test

* fixedvars are only for comp_queues

* fix hcq varvals
This commit is contained in:
George Hotz
2025-05-16 22:05:53 -07:00
committed by GitHub
parent 11b5895c85
commit e1a40e8040
2 changed files with 16 additions and 3 deletions

View File

@@ -218,6 +218,14 @@ class TestMultiTensor(unittest.TestCase):
out = f(tt)
assert out.item() == 1+2+3+4
def test_multitensor_inside_jit(self):
  # The sharding across (d1, d2) happens inside the jitted function itself,
  # rather than on the tensor that is passed in.
  @TinyJit
  def sharded_sum(t):
    shifted = t.shard((d1,d2), 0) + 1
    return shifted.contiguous().sum()

  # arange(0, 4) + 1 -> [1, 2, 3, 4]
  expected = 1+2+3+4
  # Call the jitted function several times with a freshly realized input each time
  # (presumably so both the JIT capture and replay paths get exercised -- TODO confirm).
  for _ in range(5):
    src = Tensor.arange(0, 4).contiguous().realize()
    result = sharded_sum(src)
    assert result.item() == expected
@unittest.skip("slow")
def test_fuzz_allreduce(self):
random.seed(41)

View File

@@ -1,6 +1,6 @@
import collections, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE
from tinygrad.helpers import round_up, PROFILE, merge_dicts
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
@@ -64,9 +64,14 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
for j,ji in enumerate(jit_cache):
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
# set any fixedvars on the device
self.fixedvars[enqueue_dev] = merge_dicts([self.fixedvars.get(enqueue_dev, {}), ji.fixedvars])
if is_exec_prg:
enqueue_queue = self.comp_queues[enqueue_dev]
else:
@@ -182,8 +187,8 @@ class HCQGraph(MultiGraphRunner):
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals)
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())