viz: variable duration rdna barriers (#15277)

* viz: variable length rdna barriers

* work

* tiny changes

* simple wave simd test

* small wave sync test

* good multi barrier bug find

* simple fix

* wave_sync asserts

* rdna4 work

* more rdna4

* find more bugs in my model

* it's so much simpler

* wave_sync tests duration

* r4

* should just call this rdna4
This commit is contained in:
qazal
2026-03-15 23:06:19 +02:00
committed by GitHub
parent 5cd1daa3bc
commit 4445f50356
10 changed files with 49 additions and 7 deletions

View File

@@ -9,6 +9,7 @@ EXAMPLES = {
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
"plus":"test/test_tiny.py TestTiny.test_plus",
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_wave_sync",
}
if __name__ == "__main__":

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -6,7 +6,7 @@ from tinygrad.runtime.support.elf import elf_loader
ARCH_TO_TARGET:dict[str, list[str]] = {
"rdna3":["gfx1100"],
"rdna4":["gfx1200"],
"rdna4":["gfx1200", "gfx1201"],
"cdna":["gfx950", "gfx942"],
}

View File

@@ -1,10 +1,12 @@
import unittest
import functools
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.runtime.autogen.amd.rdna3.ins import *
from tinygrad.runtime.autogen.amd.rdna4.ins import s_barrier_wait, s_barrier_signal
from tinygrad.renderer.amd.dsl import s, v
from test.amd.helpers import TARGET_TO_ARCH
def custom_add_one(A:UOp) -> UOp:
A = A.flatten()
@@ -43,9 +45,26 @@ def custom_add_var(A:UOp, B:UOp) -> UOp:
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.size}"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_wave_sync(A:UOp, arch:str) -> UOp:
# 4 waves across 1024 WG — enough to saturate a SIMD with many concurrent WGs
# s_sleep yields the SIMD so waves from different WGs interleave, causing barrier packet reordering
threads = UOp.special(128, "lidx0")
wg = UOp.special(1024, "gidx0")
insts = []
for _ in range(4):
insts.append(s_sleep(4))
insts += [s_barrier()] if arch == "rdna3" else [s_barrier_signal(), s_barrier_wait()]
insts += [s_nop(0)]*4
insts.append(s_endpgm())
sink = UOp.sink(A.base, threads, wg, arg=KernelInfo("custom_wave_sync"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
class TestCustomKernel(unittest.TestCase):
def setUp(self): self.arch = TARGET_TO_ARCH[Device["AMD"].arch]
def test_simple(self):
if self.arch != "rdna3": self.skipTest("only rdna3")
a = Tensor.full((16, 16), 1.).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=custom_add_one)[0]
ei = a.schedule()[-1].lower()
@@ -55,6 +74,7 @@ class TestCustomKernel(unittest.TestCase):
self.assertTrue((a.numpy() == 2.).all())
def test_variable(self):
if self.arch != "rdna3": self.skipTest("only rdna3")
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
a = Tensor.zeros_like(b).contiguous().realize()
a = Tensor.custom_kernel(a, b, fxn=custom_add_var)[0]
@@ -63,5 +83,9 @@ class TestCustomKernel(unittest.TestCase):
ei.run({"var":i})
self.assertTrue((a.numpy() == 1+i).all())
def test_wave_sync(self):
if self.arch not in {"rdna3", "rdna4"}: self.skipTest("only rdna3 or rdna4")
Tensor.empty(1).custom_kernel(fxn=functools.partial(custom_wave_sync, arch=self.arch))[0].realize()
if __name__ == "__main__":
unittest.main()

View File

@@ -64,6 +64,8 @@ class TestSQTTMapBase(unittest.TestCase):
def test_rocprof_inst_traces_match(self):
for name, (events, kern_events, target) in self.examples.items():
if "sync" in name and self.target.startswith("gfx12"):
self.skipTest("our timestamps are off by a few cycles because rocprof patches timestamps for rdna4 barriers")
for event in events:
if not event.itrace: continue
if event.kern not in kern_events: continue
@@ -81,7 +83,6 @@ class TestSQTTMapBase(unittest.TestCase):
mean = sum(frequency) / len(frequency)
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
self.assertGreater(mean, 0)
self.assertGreater(variance, 0)
if DEBUG >= 2: print(f"{name:20s} SE:{event.se} {mean/1e9:.2f} GHz mean, {variance/1e18:.2f} GHz^2 variance")
events = [e for e in timeline if type(e).__name__ == "ProfileRangeEvent"]
insts, execs = 0, 0
@@ -90,10 +91,21 @@ class TestSQTTMapBase(unittest.TestCase):
if "ALT" not in e.name.display_name: execs += 1
elif "WAVE" in e.device:
# sopk/immediates don't get ALU/MEM EXEC
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE"}: insts += 1
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL"}: insts += 1
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
self.assertEqual(execs, insts)
def test_wave_sync(self):
for name, (events, kern_events, target) in self.examples.items():
for event in events:
wave_barriers = {}
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
if not wave_barriers: continue
for row, events in wave_barriers.items():
for e in events:
assert e.en-e.st > 1, f"all barriers must have a duration greater than 1, got {e}"
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"

View File

@@ -128,6 +128,7 @@ class InstOpRDNA4(Enum):
LDS_WR_5 = 0x2e
OTHER_LDS_1 = 0x50
OTHER_LDS_2 = 0x51
BARRIER_SIGNAL = 0x7a
WMMA_8 = 0x8c
WMMA_16 = 0x8d
VALU_DPFP = 0x92

View File

@@ -342,13 +342,17 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4
ret:list[ProfileEvent] = []
row_ends:dict[str, Decimal] = {}
curr_barrier:dict[str, ProfileRangeEvent] = {}
NS_PER_TICK = 10 # 100MHz
prev_pair:tuple[int, int]|None = None # (shader, realtime)
def add(name:str, p:PacketType, width=1, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None:
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None:
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+width)))
# barrier on this row extends to fill the time our wave was waiting
if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time)
ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1)))
if (et:=row_ends.get(row)) is not None and e.st < et: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
row_ends[row] = unwrap(e.en)
if name == "BARRIER": curr_barrier[row] = e
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
@@ -361,7 +365,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
prev_pair = pair
if isinstance(p, (INST, INST_RDNA4)):
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}"
add(name, p, width=10 if "BARRIER" in name else 1, info=info)
add(name, p, info=info)
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)
if isinstance(p, (VMEMEXEC, ALUEXEC)):