mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
add hcq fixedvars support [pr] (#10356)
* add hcq fixedvars support [pr]
* different test
* fixedvars are only for comp_queues
* fix hcq varvals
This commit is contained in:
@@ -218,6 +218,14 @@ class TestMultiTensor(unittest.TestCase):
|
||||
out = f(tt)
|
||||
assert out.item() == 1+2+3+4
|
||||
|
||||
def test_multitensor_inside_jit(self):
  """Check that sharding performed inside a TinyJit-wrapped function works on repeated calls.

  The wrapped function shards its input with `(d1, d2)` and `0`, adds 1,
  makes the result contiguous and sums it. It is called five times on a
  freshly realized `Tensor.arange(0, 4)`, and every call must yield 1+2+3+4.
  """
  @TinyJit
  def sharded_sum(x):
    sharded = x.shard((d1,d2), 0)
    return (sharded+1).contiguous().sum()

  expected = 1+2+3+4
  # several iterations so the function is invoked repeatedly through the TinyJit wrapper
  for _ in range(5):
    src = Tensor.arange(0, 4).contiguous().realize()
    result = sharded_sum(src)
    assert result.item() == expected
|
||||
|
||||
@unittest.skip("slow")
|
||||
def test_fuzz_allreduce(self):
|
||||
random.seed(41)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import collections, time
|
||||
from typing import Any, cast
|
||||
from tinygrad.helpers import round_up, PROFILE
|
||||
from tinygrad.helpers import round_up, PROFILE, merge_dicts
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -64,9 +64,14 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
|
||||
|
||||
self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
|
||||
|
||||
for j,ji in enumerate(jit_cache):
|
||||
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
|
||||
|
||||
# set any fixedvars on the device
|
||||
self.fixedvars[enqueue_dev] = merge_dicts([self.fixedvars.get(enqueue_dev, {}), ji.fixedvars])
|
||||
|
||||
if is_exec_prg:
|
||||
enqueue_queue = self.comp_queues[enqueue_dev]
|
||||
else:
|
||||
@@ -182,8 +187,8 @@ class HCQGraph(MultiGraphRunner):
|
||||
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
|
||||
|
||||
for dev in self.devices:
|
||||
self.comp_queues[dev].submit(dev, hcq_var_vals)
|
||||
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)
|
||||
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
|
||||
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
|
||||
|
||||
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user