var_vals uses str for var (#12011)

* var_vals is str,int

* remove imports

* remove print

* fix test

* change var_vals in hcq

* update test_hcq

* fix multitensor _device_num var

* fix syminfer test

* shorten line

* p.vars stays list[Variable]

* shorten line

* vars is back to tuple[Variable, ...]

* change var_vals in extra

* change var_vals from shapetracker

* var_vals is str:int

* fix signature
This commit is contained in:
Sieds Lykles
2025-09-06 04:16:12 +02:00
committed by GitHub
parent 8658a97197
commit c6c16b2946
24 changed files with 90 additions and 91 deletions

View File

@@ -52,12 +52,12 @@ class TestHCQ(unittest.TestCase):
with self.subTest(name=str(queue_type)):
q = queue_type().signal(virt_signal, virt_val)
var_vals = {virt_signal.base_buf.va_addr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val: TestHCQ.d0.timeline_value}
var_vals = {virt_signal.base_buf.va_addr.expr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val.expr: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
var_vals = {virt_signal.base_buf.va_addr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val: TestHCQ.d0.timeline_value}
var_vals = {virt_signal.base_buf.va_addr.expr: TestHCQ.d0.timeline_signal.base_buf.va_addr, virt_val.expr: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -106,7 +106,7 @@ class TestHCQ(unittest.TestCase):
fake_signal.value = 0x30
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr: fake_signal.base_buf.va_addr, virt_val: fake_signal.value})
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr.expr: fake_signal.base_buf.va_addr, virt_val.expr: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -131,7 +131,7 @@ class TestHCQ(unittest.TestCase):
.signal(TestHCQ.d0.timeline_signal, virt_val)
for _ in range(100):
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
@@ -146,7 +146,7 @@ class TestHCQ(unittest.TestCase):
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
q.submit(TestHCQ.d0, {sint_global[0].expr: 1, sint_local[0].expr: 1})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -181,7 +181,7 @@ class TestHCQ(unittest.TestCase):
for z in range(1, 4):
ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value, virt_local[0].expr: x, virt_local[1].expr: y, virt_local[2].expr: z})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -253,7 +253,7 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.uop.buffer._buf.va_addr})
q.submit(TestHCQ.d0, {virt_src_addr.expr: TestHCQ.a.uop.buffer._buf.va_addr, virt_dest_addr.expr: TestHCQ.b.uop.buffer._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -276,7 +276,7 @@ class TestHCQ(unittest.TestCase):
.copy(virt_dest_addr, virt_src_addr, sz) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})
q.submit(TestHCQ.d0, {virt_src_addr.expr: buf2._buf.va_addr, virt_dest_addr.expr: buf1._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -299,7 +299,7 @@ class TestHCQ(unittest.TestCase):
fake_signal.value = 0x30
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr: fake_signal.base_buf.va_addr, virt_val: fake_signal.value})
q.submit(TestHCQ.d0, {virt_signal.base_buf.va_addr.expr: fake_signal.base_buf.va_addr, virt_val.expr: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1