mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
hcq replace update with sint (#7899)
* try sym hcq * start with amd * move to nv * nv works * cache and qcom * fixes * signals * fix nv * qcom fixes * linter * linter * cache + typings * fixes * tiny fixes * linter * linter * lntr * ugh * comments
This commit is contained in:
@@ -6,6 +6,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner, CompiledRunner
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad import Variable
|
||||
|
||||
MOCKGPU = getenv("MOCKGPU")
|
||||
|
||||
@@ -44,14 +45,19 @@ class TestHCQ(unittest.TestCase):
|
||||
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
|
||||
if queue_type is None: continue
|
||||
|
||||
with self.subTest(name=str(queue_type)):
|
||||
q = queue_type().signal(TestHCQ.d0.signal_t(), 0x1000)
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
|
||||
|
||||
q.update_signal(0, signal=TestHCQ.d0.timeline_signal, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
with self.subTest(name=str(queue_type)):
|
||||
q = queue_type().signal(virt_signal, virt_val)
|
||||
|
||||
var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
|
||||
q.submit(TestHCQ.d0, var_vals)
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
q.update_signal(0, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
|
||||
q.submit(TestHCQ.d0, var_vals)
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
@@ -91,12 +97,15 @@ class TestHCQ(unittest.TestCase):
|
||||
if queue_type is None: continue
|
||||
|
||||
with self.subTest(name=str(queue_type)):
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
|
||||
|
||||
fake_signal = TestHCQ.d0.signal_t()
|
||||
q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
||||
fake_signal.value = 0x30
|
||||
|
||||
q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
@@ -112,26 +121,30 @@ class TestHCQ(unittest.TestCase):
|
||||
assert val == 1.0, f"got val {val}"
|
||||
|
||||
def test_exec_2_kernels_100_times(self):
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
|
||||
q = TestHCQ.d0.hw_compute_queue_t()
|
||||
q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
|
||||
q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
.signal(TestHCQ.d0.timeline_signal, virt_val)
|
||||
|
||||
for _ in range(100):
|
||||
q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||
assert val == 200.0, f"got val {val}"
|
||||
|
||||
def test_exec_update(self):
|
||||
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
|
||||
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
|
||||
|
||||
q = TestHCQ.d0.hw_compute_queue_t()
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
||||
q.update_exec(0, (1,1,1), (1,1,1))
|
||||
q.submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
@@ -141,6 +154,9 @@ class TestHCQ(unittest.TestCase):
|
||||
assert val == 0.0, f"got val {val}, should not be updated"
|
||||
|
||||
def test_exec_update_fuzz(self):
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
|
||||
|
||||
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
|
||||
b = a + 1
|
||||
si = create_schedule([b.lazydata])[-1]
|
||||
@@ -156,16 +172,15 @@ class TestHCQ(unittest.TestCase):
|
||||
|
||||
q = TestHCQ.d0.hw_compute_queue_t()
|
||||
q.memory_barrier() \
|
||||
.exec(runner._prg, kernargs, (1,1,1), (1,1,1)) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
.exec(runner._prg, kernargs, (1,1,1), virt_local) \
|
||||
.signal(TestHCQ.d0.timeline_signal, virt_val)
|
||||
|
||||
for x in range(1, 4):
|
||||
for y in range(1, 4):
|
||||
for z in range(1, 4):
|
||||
ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)
|
||||
|
||||
q.update_exec(1, local_size=(x,y,z)) \
|
||||
.update_signal(2, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
@@ -207,12 +222,14 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_update_copy(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
|
||||
virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
|
||||
|
||||
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
|
||||
.copy(0x0, 0x0, 8) \
|
||||
.copy(virt_dest_addr, virt_src_addr, 8) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
||||
q.update_copy(1, dest=TestHCQ.b.lazydata.buffer._buf.va_addr, src=TestHCQ.a.lazydata.buffer._buf.va_addr) \
|
||||
.submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr})
|
||||
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
@@ -223,17 +240,19 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_update_copy_long(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
||||
virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
|
||||
virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
|
||||
|
||||
sz = 64 << 20
|
||||
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
|
||||
ctypes.memset(buf2._buf.va_addr, 1, sz)
|
||||
|
||||
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
|
||||
.copy(0x0, 0x0, sz) \
|
||||
.copy(virt_dest_addr, virt_src_addr, sz) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
||||
q.update_copy(1, buf1._buf.va_addr, buf2._buf.va_addr) \
|
||||
.submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})
|
||||
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
@@ -246,14 +265,17 @@ class TestHCQ(unittest.TestCase):
|
||||
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
|
||||
if queue_type is None: continue
|
||||
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
|
||||
|
||||
with self.subTest(name=str(queue_type)):
|
||||
fake_signal = TestHCQ.d0.signal_t()
|
||||
q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
q.bind(TestHCQ.d0)
|
||||
|
||||
fake_signal.value = 0x30
|
||||
|
||||
q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
|
||||
q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
|
||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||
TestHCQ.d0.timeline_value += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user