hcq replace update with sint (#7899)

* try sym hcq

* start with amd

* move to nv

* nv works

* cache and qcom

* fixes

* signals

* fix nv

* qcom fixes

* linter

* linter

* cache + typings

* fixes

* tiny fixes

* linter

* linter

* lntr

* ugh

* comments
This commit is contained in:
nimlgen
2024-11-29 20:08:13 +03:00
committed by GitHub
parent aa51f3c14e
commit 10f431b96d
7 changed files with 261 additions and 358 deletions

View File

@@ -6,6 +6,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner, CompiledRunner
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad import Variable
MOCKGPU = getenv("MOCKGPU")
@@ -44,14 +45,19 @@ class TestHCQ(unittest.TestCase):
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
if queue_type is None: continue
with self.subTest(name=str(queue_type)):
q = queue_type().signal(TestHCQ.d0.signal_t(), 0x1000)
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
q.update_signal(0, signal=TestHCQ.d0.timeline_signal, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
with self.subTest(name=str(queue_type)):
q = queue_type().signal(virt_signal, virt_val)
var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
q.update_signal(0, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -91,12 +97,15 @@ class TestHCQ(unittest.TestCase):
if queue_type is None: continue
with self.subTest(name=str(queue_type)):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
fake_signal = TestHCQ.d0.signal_t()
q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
fake_signal.value = 0x30
q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -112,26 +121,30 @@ class TestHCQ(unittest.TestCase):
assert val == 1.0, f"got val {val}"
def test_exec_2_kernels_100_times(self):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
q = TestHCQ.d0.hw_compute_queue_t()
q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
.signal(TestHCQ.d0.timeline_signal, virt_val)
for _ in range(100):
q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 200.0, f"got val {val}"
def test_exec_update(self):
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
q = TestHCQ.d0.hw_compute_queue_t()
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.update_exec(0, (1,1,1), (1,1,1))
q.submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -141,6 +154,9 @@ class TestHCQ(unittest.TestCase):
assert val == 0.0, f"got val {val}, should not be updated"
def test_exec_update_fuzz(self):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
b = a + 1
si = create_schedule([b.lazydata])[-1]
@@ -156,16 +172,15 @@ class TestHCQ(unittest.TestCase):
q = TestHCQ.d0.hw_compute_queue_t()
q.memory_barrier() \
.exec(runner._prg, kernargs, (1,1,1), (1,1,1)) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
.exec(runner._prg, kernargs, (1,1,1), virt_local) \
.signal(TestHCQ.d0.timeline_signal, virt_val)
for x in range(1, 4):
for y in range(1, 4):
for z in range(1, 4):
ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)
q.update_exec(1, local_size=(x,y,z)) \
.update_signal(2, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -207,12 +222,14 @@ class TestHCQ(unittest.TestCase):
def test_update_copy(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
.copy(0x0, 0x0, 8) \
.copy(virt_dest_addr, virt_src_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.update_copy(1, dest=TestHCQ.b.lazydata.buffer._buf.va_addr, src=TestHCQ.a.lazydata.buffer._buf.va_addr) \
.submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -223,17 +240,19 @@ class TestHCQ(unittest.TestCase):
def test_update_copy_long(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
sz = 64 << 20
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
ctypes.memset(buf2._buf.va_addr, 1, sz)
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
.copy(0x0, 0x0, sz) \
.copy(virt_dest_addr, virt_src_addr, sz) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.update_copy(1, buf1._buf.va_addr, buf2._buf.va_addr) \
.submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -246,14 +265,17 @@ class TestHCQ(unittest.TestCase):
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
if queue_type is None: continue
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
with self.subTest(name=str(queue_type)):
fake_signal = TestHCQ.d0.signal_t()
q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.bind(TestHCQ.d0)
fake_signal.value = 0x30
q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1