mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
sqtt test_timing work (#13304)
* sqtt test_timing cleanups * only the instruction * v_mfma_f32_16x16x32_f16 16 cycles, only after second one though
This commit is contained in:
@@ -39,10 +39,10 @@ def save_sqtt():
|
||||
sqtt:dict[str, list[WaveExec]] = {}
|
||||
yield sqtt
|
||||
# decode sqtt
|
||||
if os.environ["DEV"] == "AMD":
|
||||
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
|
||||
assert len(rctx.inst_execs) > 0, "empty sqtt output"
|
||||
sqtt.update(rctx.inst_execs)
|
||||
if os.environ["DEV"] != "AMD": return
|
||||
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
|
||||
assert len(rctx.inst_execs) > 0, "empty sqtt output"
|
||||
sqtt.update(rctx.inst_execs)
|
||||
|
||||
class TestTiming(unittest.TestCase):
|
||||
def test_v_add(self):
|
||||
@@ -84,13 +84,17 @@ class TestTiming(unittest.TestCase):
|
||||
|
||||
def test_wmma(self):
|
||||
with save_sqtt() as sqtt:
|
||||
asm_kernel([
|
||||
"v_wmma_f32_16x16x16_f16 v[16:23], v[0:7], v[8:15], v[16:23]",
|
||||
"v_add_f32_e32 v0 v16 v0",
|
||||
], l=32*4).realize()
|
||||
assert len(sqtt) == 2, f"expected two waves, got {len(sqtt)} {list(sqtt.keys())}"
|
||||
wmma = list(sqtt.values())[0][0]
|
||||
self.assertGreater(wmma.dur, 1) # rgp says 32 clocks
|
||||
for tc in dev.renderer.get_tensor_cores(dev.arch):
|
||||
M, K, N = tc.dims
|
||||
s = 32
|
||||
a = Tensor.empty(M*s, K*s, dtype=tc.dtype_in)@Tensor.empty(K*s, N*s, dtype=tc.dtype_in)
|
||||
a.realize()
|
||||
print(a)
|
||||
for p,waves in sqtt.items():
|
||||
for e in waves[0].insts:
|
||||
if (e.inst.startswith("v_wmma")):
|
||||
instruction = e.inst.split(" ")[0]
|
||||
print(f"{instruction:<29} : {e.dur} cycles")
|
||||
|
||||
def test_sleep(self):
|
||||
n = 1
|
||||
@@ -105,8 +109,9 @@ class TestTiming(unittest.TestCase):
|
||||
diff_hw_reg = Tensor.custom_kernel(diff_hw_reg, fxn=sleep_kernel)[0]
|
||||
with save_sqtt() as sqtt:
|
||||
diff_hw_reg.realize()
|
||||
diff_sqtt = list(sqtt.values())[0][2]
|
||||
self.assertEqual(diff_sqtt.dur, diff_hw_reg.item()-1) # 1 cycle for reading the counter register
|
||||
sleep = next((e for e in sqtt[f"sleep_{n}"][0].insts if e.inst.startswith("s_sleep")))
|
||||
# cycles = sleep dur + overhead of storing hi/lo REG_SHADER_CYCLES
|
||||
self.assertGreaterEqual(diff_hw_reg.item(), sleep.dur)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user