address kfd feedback (#4087)

* address kfd feedback

* signals cleanup

* signals cleanup

* handle 2 doorbell pages correctly

* signal reset cleanup

* signals cleanup

* more GTT

* cleanups

* minor cleanups
This commit is contained in:
George Hotz
2024-04-05 15:24:41 -07:00
committed by GitHub
parent dafa42e864
commit 164329a8ea
2 changed files with 103 additions and 107 deletions

View File

@@ -8,7 +8,7 @@ def _time_queue(q, d):
st = time.perf_counter()
q.signal(d.completion_signal)
q.submit(d)
d._wait_on(d.completion_signal.event_id)
d._wait_signal(d.completion_signal)
return time.perf_counter() - st
class TestHCQ(unittest.TestCase):
@@ -38,7 +38,7 @@ class TestHCQ(unittest.TestCase):
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+len(TestHCQ.addr), TestHCQ.runner.global_size, TestHCQ.runner.local_size)
q.signal(TestHCQ.d0.completion_signal)
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
def test_run_1000_times(self):
@@ -48,10 +48,10 @@ class TestHCQ(unittest.TestCase):
TestHCQ.runner.local_size, TestHCQ.d0.completion_signal)
for _ in range(1000):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
# confirm signal was reset
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=50)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=50)
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
def test_run_to_3(self):
@@ -60,32 +60,32 @@ class TestHCQ(unittest.TestCase):
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+len(TestHCQ.addr), TestHCQ.runner.global_size, TestHCQ.runner.local_size)
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size, TestHCQ.d0.completion_signal)
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}"
def test_wait_signal(self):
# Sets completion_signal.value to 1, submits a compute queue that waits on and then signals that same
# signal, and asserts that the host-side wait raises RuntimeError; afterwards resets the value to 0 and waits again.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation. Each
# `_wait_on(...event_id...)` line appears to be the removed form and the `_wait_signal(...)` line right below it
# the replacement -- confirm against the real file before treating both calls as live code.
TestHCQ.d0.completion_signal.value = 1
# one chained queue: wait(signal) -> signal(signal) -> submit to d0
HWComputeQueue().wait(TestHCQ.d0.completion_signal).signal(TestHCQ.d0.completion_signal).submit(TestHCQ.d0)
# presumably the queued wait stays blocked while value is 1, so the wait below times out -- TODO confirm
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=50)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=50)
# clean up
TestHCQ.d0.completion_signal.value = 0
# no assertRaises here: this wait (timeout=1000) is expected to return without raising
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=1000)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=1000)
def test_wait_copy_signal(self):
# Same sequence as the compute-queue wait test, but the wait/signal pair is submitted through HWCopyQueue:
# set completion_signal.value to 1, submit, assert the host-side wait raises RuntimeError, then reset and wait again.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation. Each
# `_wait_on(...event_id...)` line appears to be the removed form and the `_wait_signal(...)` line right below it
# the replacement -- confirm against the real file before treating both calls as live code.
TestHCQ.d0.completion_signal.value = 1
# one chained queue: wait(signal) -> signal(signal) -> submit to d0
HWCopyQueue().wait(TestHCQ.d0.completion_signal).signal(TestHCQ.d0.completion_signal).submit(TestHCQ.d0)
# presumably the queued wait stays blocked while value is 1, so the wait below times out -- TODO confirm
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=50)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=50)
# clean up
TestHCQ.d0.completion_signal.value = 0
# no assertRaises here: this wait (timeout=1000) is expected to return without raising
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=1000)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=1000)
def test_run_normal(self):
# Runs TestHCQ.runner.clprg once through a compute queue and asserts that the first float of buffer b reads 1.0.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation; the `_wait_on` line
# appears to be the removed form and the `_wait_signal` line below it the replacement -- confirm against the real file.
q = HWComputeQueue()
# completion_signal is passed as the final exec argument instead of a separate q.signal(...) call;
# presumably exec signals it when the kernel finishes -- verify against HWComputeQueue.exec
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size, TestHCQ.d0.completion_signal)
q.submit(TestHCQ.d0)
# wait with no explicit timeout before reading the result back
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
def test_submit_empty_queues(self):
@@ -94,22 +94,22 @@ class TestHCQ(unittest.TestCase):
def test_signal_timeout(self):
# Asserts that waiting on d0's completion signal with timeout=50 raises RuntimeError when this test
# has submitted nothing that would signal it.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation; the `_wait_on` line
# appears to be the removed form and the `_wait_signal` line below it the replacement -- confirm against the real file.
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=50)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=50)
def test_signal(self):
# Submits a compute queue containing only a signal of d0's completion signal, then waits on that signal;
# the test has no assert, so it passes as long as the wait returns without raising.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation; the `_wait_on` line
# appears to be the removed form and the `_wait_signal` line below it the replacement -- confirm against the real file.
HWComputeQueue().signal(TestHCQ.d0.completion_signal).submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
def test_copy_signal(self):
# Submits a copy queue containing only a signal of d0's completion signal, then waits on that signal;
# the test has no assert, so it passes as long as the wait returns without raising.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation; the `_wait_on` line
# appears to be the removed form and the `_wait_signal` line below it the replacement -- confirm against the real file.
HWCopyQueue().signal(TestHCQ.d0.completion_signal).submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
def test_run_signal(self):
# Runs TestHCQ.runner.clprg once through a compute queue, with the completion signal queued as a separate
# q.signal(...) step after the exec, then asserts that the first float of buffer b reads 1.0.
# NOTE(review): this listing looks like a diff rendered without +/- markers or indentation; the `_wait_on` line
# appears to be the removed form and the `_wait_signal` line below it the replacement -- confirm against the real file.
q = HWComputeQueue()
# exec is called without a signal argument here; the signal is enqueued explicitly on the next line
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.global_size, TestHCQ.runner.local_size)
q.signal(TestHCQ.d0.completion_signal)
q.submit(TestHCQ.d0)
# wait with no explicit timeout before reading the result back
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
def test_copy_1000_times(self):
@@ -119,10 +119,10 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.completion_signal)
for _ in range(1000):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
# confirm signal was reset
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id, timeout=50)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal, timeout=50)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"
def test_copy(self):
@@ -130,7 +130,7 @@ class TestHCQ(unittest.TestCase):
q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
q.signal(TestHCQ.d0.completion_signal)
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
def test_copy_bandwidth(self):
@@ -169,7 +169,7 @@ class TestHCQ(unittest.TestCase):
qc.submit(TestHCQ.d0)
time.sleep(0.02) # give it time for the wait to fail
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
def test_cross_device_signal(self):
@@ -179,7 +179,7 @@ class TestHCQ(unittest.TestCase):
q2.wait(TestHCQ.d0.completion_signal)
q2.submit(TestHCQ.d0)
q1.submit(TestHCQ.d1)
TestHCQ.d0._wait_on(TestHCQ.d0.completion_signal.event_id)
TestHCQ.d0._wait_signal(TestHCQ.d0.completion_signal)
if __name__ == "__main__":
unittest.main()