diff --git a/extra/sqtt/examples/generate_examples.py b/extra/sqtt/examples/generate_examples.py index fcdc301409..92146dfb20 100644 --- a/extra/sqtt/examples/generate_examples.py +++ b/extra/sqtt/examples/generate_examples.py @@ -9,6 +9,7 @@ EXAMPLES = { "empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty", "plus":"test/test_tiny.py TestTiny.test_plus", "gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"", + "sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_wave_sync", } if __name__ == "__main__": diff --git a/extra/sqtt/examples/gfx1100/profile_sync_run_0.pkl b/extra/sqtt/examples/gfx1100/profile_sync_run_0.pkl new file mode 100644 index 0000000000..a6292729b5 Binary files /dev/null and b/extra/sqtt/examples/gfx1100/profile_sync_run_0.pkl differ diff --git a/extra/sqtt/examples/gfx1100/profile_sync_run_1.pkl b/extra/sqtt/examples/gfx1100/profile_sync_run_1.pkl new file mode 100644 index 0000000000..605e444a68 Binary files /dev/null and b/extra/sqtt/examples/gfx1100/profile_sync_run_1.pkl differ diff --git a/extra/sqtt/examples/gfx1200/profile_sync_run_0.pkl b/extra/sqtt/examples/gfx1200/profile_sync_run_0.pkl new file mode 100644 index 0000000000..98c36ea5ec Binary files /dev/null and b/extra/sqtt/examples/gfx1200/profile_sync_run_0.pkl differ diff --git a/extra/sqtt/examples/gfx1200/profile_sync_run_1.pkl b/extra/sqtt/examples/gfx1200/profile_sync_run_1.pkl new file mode 100644 index 0000000000..42ac0c1f63 Binary files /dev/null and b/extra/sqtt/examples/gfx1200/profile_sync_run_1.pkl differ diff --git a/test/amd/helpers.py b/test/amd/helpers.py index 05c983539c..367268924e 100644 --- a/test/amd/helpers.py +++ b/test/amd/helpers.py @@ -6,7 +6,7 @@ from tinygrad.runtime.support.elf import elf_loader ARCH_TO_TARGET:dict[str, list[str]] = { "rdna3":["gfx1100"], - "rdna4":["gfx1200"], + "rdna4":["gfx1200", "gfx1201"], "cdna":["gfx950", "gfx942"], } diff --git a/test/amd/test_custom_kernel.py b/test/amd/test_custom_kernel.py index 1e32979d59..21f1042c20 100644 --- a/test/amd/test_custom_kernel.py +++ b/test/amd/test_custom_kernel.py @@ -1,10 +1,12 @@ import unittest +import functools from tinygrad import Tensor, Device, dtypes from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.renderer import Estimates - from tinygrad.runtime.autogen.amd.rdna3.ins import * +from tinygrad.runtime.autogen.amd.rdna4.ins import s_barrier_wait, s_barrier_signal from tinygrad.renderer.amd.dsl import s, v +from test.amd.helpers import TARGET_TO_ARCH def custom_add_one(A:UOp) -> UOp: A = A.flatten() @@ -43,9 +45,26 @@ def custom_add_var(A:UOp, B:UOp) -> UOp: sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.size}")) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts])))) +def custom_wave_sync(A:UOp, arch:str) -> UOp: + # 4 waves across 1024 WG — enough to saturate a SIMD with many concurrent WGs + # s_sleep yields the SIMD so waves from different WGs interleave, causing barrier packet reordering + threads = UOp.special(128, "lidx0") + wg = UOp.special(1024, "gidx0") + insts = [] + for _ in range(4): + insts.append(s_sleep(4)) + insts += [s_barrier()] if arch == "rdna3" else [s_barrier_signal(), s_barrier_wait()] + insts += [s_nop(0)]*4 + insts.append(s_endpgm()) + sink = UOp.sink(A.base, threads, wg, arg=KernelInfo("custom_wave_sync")) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts])))) + @unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device") class TestCustomKernel(unittest.TestCase): + def setUp(self): self.arch = TARGET_TO_ARCH[Device["AMD"].arch] + def test_simple(self): + if self.arch != "rdna3": self.skipTest("only rdna3") a = Tensor.full((16, 16), 1.).contiguous().realize() a = Tensor.custom_kernel(a, fxn=custom_add_one)[0] ei = a.schedule()[-1].lower() @@ -55,6 +74,7 @@ class TestCustomKernel(unittest.TestCase): self.assertTrue((a.numpy() == 2.).all()) def test_variable(self): + if self.arch != "rdna3": self.skipTest("only rdna3") b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize() a = Tensor.zeros_like(b).contiguous().realize() a = Tensor.custom_kernel(a, b, fxn=custom_add_var)[0] @@ -63,5 +83,9 @@ class TestCustomKernel(unittest.TestCase): ei.run({"var":i}) self.assertTrue((a.numpy() == 1+i).all()) + def test_wave_sync(self): + if self.arch not in {"rdna3", "rdna4"}: self.skipTest("only rdna3 or rdna4") + Tensor.empty(1).custom_kernel(fxn=functools.partial(custom_wave_sync, arch=self.arch))[0].realize() + if __name__ == "__main__": unittest.main() diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index f5bf398ccb..eec71b8329 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -64,6 +64,8 @@ class TestSQTTMapBase(unittest.TestCase): def test_rocprof_inst_traces_match(self): for name, (events, kern_events, target) in self.examples.items(): + if "sync" in name and self.target.startswith("gfx12"): + self.skipTest("our timestamps are off by a few cycles because rocprof patches timestamps for rdna4 barriers") for event in events: if not event.itrace: continue if event.kern not in kern_events: continue @@ -81,7 +83,6 @@ class TestSQTTMapBase(unittest.TestCase): mean = sum(frequency) / len(frequency) variance = sum((v - mean) ** 2 for v in frequency) / len(frequency) self.assertGreater(mean, 0) - self.assertGreater(variance, 0) if DEBUG >= 2: print(f"{name:20s} SE:{event.se} {mean/1e9:.2f} GHz mean, {variance/1e18:.2f} GHz^2 variance") events = [e for e in timeline if type(e).__name__ == "ProfileRangeEvent"] insts, execs = 0, 0 @@ -90,10 +91,21 @@ class TestSQTTMapBase(unittest.TestCase): if "ALT" not in e.name.display_name: execs += 1 elif "WAVE" in e.device: # sopk/immediates don't get ALU/MEM EXEC - if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE"}: insts += 1 + if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL"}: insts += 1 else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}") self.assertEqual(execs, insts) + def test_wave_sync(self): + for name, (events, kern_events, target) in self.examples.items(): + for event in events: + wave_barriers = {} + for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target): + if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e) + if not wave_barriers: continue + for row, events in wave_barriers.items(): + for e in events: + assert e.en-e.st > 1, f"all barriers must have a duration greater than 1, got {e}" + class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100" class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200" diff --git a/tinygrad/renderer/amd/sqtt.py b/tinygrad/renderer/amd/sqtt.py index b89ea90fdb..43a1ce62ac 100644 --- a/tinygrad/renderer/amd/sqtt.py +++ b/tinygrad/renderer/amd/sqtt.py @@ -128,6 +128,7 @@ class InstOpRDNA4(Enum): LDS_WR_5 = 0x2e OTHER_LDS_1 = 0x50 OTHER_LDS_2 = 0x51 + BARRIER_SIGNAL = 0x7a WMMA_8 = 0x8c WMMA_16 = 0x8d VALU_DPFP = 0x92 diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index d051a17c8a..146cb919ae 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -342,13 +342,17 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4 ret:list[ProfileEvent] = [] row_ends:dict[str, Decimal] = {} + curr_barrier:dict[str, ProfileRangeEvent] = {} NS_PER_TICK = 10 # 100MHz prev_pair:tuple[int, int]|None = None # (shader, realtime) - def add(name:str, p:PacketType, width=1, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None: + def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None: row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}" - ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+width))) + # barrier on this row extends to fill the time our wave was waiting + if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time) + ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1))) if (et:=row_ends.get(row)) is not None and e.st < et: raise RuntimeError(f"packet {p} overlaps another packet in {row}.") row_ends[row] = unwrap(e.en) + if name == "BARRIER": curr_barrier[row] = e for p, info in map_insts(data, lib, target): if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker: @@ -361,7 +365,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: prev_pair = pair if isinstance(p, (INST, INST_RDNA4)): name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}" - add(name, p, width=10 if "BARRIER" in name else 1, info=info) + add(name, p, info=info) if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info) if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info) if isinstance(p, (VMEMEXEC, ALUEXEC)):