mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: variable duration rdna barriers (#15277)
* viz: variable length rdna barriers * work * tiny changes * simple wave simd test * small wave sync test * good multi barrier bug find * simple fix * wave_sync asserts * rdna4 work * more rdna4 * find more bugs in my model * it's so much simpler * wave_sync tests duration * r4 * should just call this rdna4
This commit is contained in:
@@ -9,6 +9,7 @@ EXAMPLES = {
|
||||
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
|
||||
"plus":"test/test_tiny.py TestTiny.test_plus",
|
||||
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
|
||||
"sync":"test/amd/test_custom_kernel.py TestCustomKernel.test_wave_sync",
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
BIN
extra/sqtt/examples/gfx1100/profile_sync_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_sync_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1100/profile_sync_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1100/profile_sync_run_1.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_sync_run_0.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_sync_run_0.pkl
Normal file
Binary file not shown.
BIN
extra/sqtt/examples/gfx1200/profile_sync_run_1.pkl
Normal file
BIN
extra/sqtt/examples/gfx1200/profile_sync_run_1.pkl
Normal file
Binary file not shown.
@@ -6,7 +6,7 @@ from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
ARCH_TO_TARGET:dict[str, list[str]] = {
|
||||
"rdna3":["gfx1100"],
|
||||
"rdna4":["gfx1200"],
|
||||
"rdna4":["gfx1200", "gfx1201"],
|
||||
"cdna":["gfx950", "gfx942"],
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import unittest
|
||||
import functools
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
from tinygrad.runtime.autogen.amd.rdna4.ins import s_barrier_wait, s_barrier_signal
|
||||
from tinygrad.renderer.amd.dsl import s, v
|
||||
from test.amd.helpers import TARGET_TO_ARCH
|
||||
|
||||
def custom_add_one(A:UOp) -> UOp:
|
||||
A = A.flatten()
|
||||
@@ -43,9 +45,26 @@ def custom_add_var(A:UOp, B:UOp) -> UOp:
|
||||
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.size}"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
def custom_wave_sync(A:UOp, arch:str) -> UOp:
|
||||
# 4 waves across 1024 WG — enough to saturate a SIMD with many concurrent WGs
|
||||
# s_sleep yields the SIMD so waves from different WGs interleave, causing barrier packet reordering
|
||||
threads = UOp.special(128, "lidx0")
|
||||
wg = UOp.special(1024, "gidx0")
|
||||
insts = []
|
||||
for _ in range(4):
|
||||
insts.append(s_sleep(4))
|
||||
insts += [s_barrier()] if arch == "rdna3" else [s_barrier_signal(), s_barrier_wait()]
|
||||
insts += [s_nop(0)]*4
|
||||
insts.append(s_endpgm())
|
||||
sink = UOp.sink(A.base, threads, wg, arg=KernelInfo("custom_wave_sync"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
|
||||
class TestCustomKernel(unittest.TestCase):
|
||||
def setUp(self): self.arch = TARGET_TO_ARCH[Device["AMD"].arch]
|
||||
|
||||
def test_simple(self):
|
||||
if self.arch != "rdna3": self.skipTest("only rdna3")
|
||||
a = Tensor.full((16, 16), 1.).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, fxn=custom_add_one)[0]
|
||||
ei = a.schedule()[-1].lower()
|
||||
@@ -55,6 +74,7 @@ class TestCustomKernel(unittest.TestCase):
|
||||
self.assertTrue((a.numpy() == 2.).all())
|
||||
|
||||
def test_variable(self):
|
||||
if self.arch != "rdna3": self.skipTest("only rdna3")
|
||||
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
|
||||
a = Tensor.zeros_like(b).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, b, fxn=custom_add_var)[0]
|
||||
@@ -63,5 +83,9 @@ class TestCustomKernel(unittest.TestCase):
|
||||
ei.run({"var":i})
|
||||
self.assertTrue((a.numpy() == 1+i).all())
|
||||
|
||||
def test_wave_sync(self):
|
||||
if self.arch not in {"rdna3", "rdna4"}: self.skipTest("only rdna3 or rdna4")
|
||||
Tensor.empty(1).custom_kernel(fxn=functools.partial(custom_wave_sync, arch=self.arch))[0].realize()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -64,6 +64,8 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
|
||||
def test_rocprof_inst_traces_match(self):
|
||||
for name, (events, kern_events, target) in self.examples.items():
|
||||
if "sync" in name and self.target.startswith("gfx12"):
|
||||
self.skipTest("our timestamps are off by a few cycles because rocprof patches timestamps for rdna4 barriers")
|
||||
for event in events:
|
||||
if not event.itrace: continue
|
||||
if event.kern not in kern_events: continue
|
||||
@@ -81,7 +83,6 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
mean = sum(frequency) / len(frequency)
|
||||
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
|
||||
self.assertGreater(mean, 0)
|
||||
self.assertGreater(variance, 0)
|
||||
if DEBUG >= 2: print(f"{name:20s} SE:{event.se} {mean/1e9:.2f} GHz mean, {variance/1e18:.2f} GHz^2 variance")
|
||||
events = [e for e in timeline if type(e).__name__ == "ProfileRangeEvent"]
|
||||
insts, execs = 0, 0
|
||||
@@ -90,10 +91,21 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
if "ALT" not in e.name.display_name: execs += 1
|
||||
elif "WAVE" in e.device:
|
||||
# sopk/immediates don't get ALU/MEM EXEC
|
||||
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE"}: insts += 1
|
||||
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL"}: insts += 1
|
||||
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
|
||||
self.assertEqual(execs, insts)
|
||||
|
||||
def test_wave_sync(self):
|
||||
for name, (events, kern_events, target) in self.examples.items():
|
||||
for event in events:
|
||||
wave_barriers = {}
|
||||
for e in sqtt_timeline(event.blob, kern_events[event.kern].lib, target):
|
||||
if type(e).__name__ == "ProfileRangeEvent" and e.name.display_name == "BARRIER": wave_barriers.setdefault(e.device, []).append(e)
|
||||
if not wave_barriers: continue
|
||||
for row, events in wave_barriers.items():
|
||||
for e in events:
|
||||
assert e.en-e.st > 1, f"all barriers must have a duration greater than 1, got {e}"
|
||||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
||||
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
|
||||
|
||||
@@ -128,6 +128,7 @@ class InstOpRDNA4(Enum):
|
||||
LDS_WR_5 = 0x2e
|
||||
OTHER_LDS_1 = 0x50
|
||||
OTHER_LDS_2 = 0x51
|
||||
BARRIER_SIGNAL = 0x7a
|
||||
WMMA_8 = 0x8c
|
||||
WMMA_16 = 0x8d
|
||||
VALU_DPFP = 0x92
|
||||
|
||||
@@ -342,13 +342,17 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4
|
||||
ret:list[ProfileEvent] = []
|
||||
row_ends:dict[str, Decimal] = {}
|
||||
curr_barrier:dict[str, ProfileRangeEvent] = {}
|
||||
NS_PER_TICK = 10 # 100MHz
|
||||
prev_pair:tuple[int, int]|None = None # (shader, realtime)
|
||||
def add(name:str, p:PacketType, width=1, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None:
|
||||
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None:
|
||||
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
|
||||
ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+width)))
|
||||
# barrier on this row extends to fill the time our wave was waiting
|
||||
if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time)
|
||||
ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1)))
|
||||
if (et:=row_ends.get(row)) is not None and e.st < et: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
|
||||
row_ends[row] = unwrap(e.en)
|
||||
if name == "BARRIER": curr_barrier[row] = e
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
|
||||
@@ -361,7 +365,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
prev_pair = pair
|
||||
if isinstance(p, (INST, INST_RDNA4)):
|
||||
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}"
|
||||
add(name, p, width=10 if "BARRIER" in name else 1, info=info)
|
||||
add(name, p, info=info)
|
||||
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
|
||||
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)):
|
||||
|
||||
Reference in New Issue
Block a user