Merge branch 'master' into typed_checks

George Hotz committed 2025-12-22 13:24:10 -05:00 (via GitHub)
22 changed files with 545 additions and 268 deletions

View File

@@ -1,9 +0,0 @@
import globals from "globals";
import pluginJs from "@eslint/js";
import pluginHtml from "eslint-plugin-html";
export default [
{files: ["**/*.html"], plugins: {html: pluginHtml}, rules:{"max-len": ["error", {"code": 150}]}},
{languageOptions: {globals: globals.browser}},
pluginJs.configs.recommended,
];

View File

@@ -1438,28 +1438,34 @@ def train_llama3():
iter = get_train_iter()
i, sequences_seen = resume_ckpt, 0
for tokens in tqdm(iter, total=SAMPLES//GBS):
t = time.perf_counter()
GlobalCounters.reset()
loss, lr = train_step(model, tokens)
loss = loss.float().item()
if getenv("TRAIN", 1):
t = time.perf_counter()
loss, lr = train_step(model, tokens)
loss = loss.float().item()
lr = lr.item()
i += 1
sequences_seen += tokens.shape[0]
i += 1
sequences_seen += tokens.shape[0]
tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
if (fname:=getenv("LOSS_FILE", "")):
with open(fname, "a") as f:
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
sec = time.perf_counter()-t
tqdm.write(
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
tqdm.write("saving checkpoint")
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3_{i}.safe"
safe_save(get_state_dict(model), fn)
if (fname:=getenv("LOSS_FILE", "")):
with open(fname, "a") as f:
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
tqdm.write("saving optim checkpoint")
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
safe_save(get_state_dict(scheduler), fn)
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
tqdm.write("saving checkpoint")
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3_{i}.safe"
safe_save(get_state_dict(model), fn)
tqdm.write("saving optim checkpoint")
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
safe_save(get_state_dict(scheduler), fn)
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
tqdm.write(f"evaluating after {sequences_seen} sequences")

View File

@@ -184,7 +184,7 @@ class SMICtx:
if compact: return {k: temps[k] for k in ("Hotspot", "HBM") if temps.get(k, 0) != 0}
return {k: v for k, v in temps.items() if v != 0}
case _:
temps_keys = [(k, name) for k, name in dev.smu.smu_mod.c__EA_TEMP_e__enumvalues.items()
temps_keys = [(k, name) for k, name in dev.smu.smu_mod.TEMP_e.items()
if k < dev.smu.smu_mod.TEMP_COUNT and metrics.SmuMetrics.AvgTemperature[k] != 0]
if compact: temps_keys = [(k, name) for k, name in temps_keys if k in (dev.smu.smu_mod.TEMP_HOTSPOT, dev.smu.smu_mod.TEMP_MEM)]
return {name: metrics.SmuMetrics.AvgTemperature[k] for k, name in temps_keys}
@@ -193,7 +193,7 @@ class SMICtx:
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return {}
case _:
voltage_keys = [(k, name) for k, name in dev.smu.smu_mod.c__EA_SVI_PLANE_e__enumvalues.items()
voltage_keys = [(k, name) for k, name in dev.smu.smu_mod.SVI_PLANE_e.items()
if k < dev.smu.smu_mod.SVI_PLANE_COUNT and metrics.SmuMetrics.AvgVoltage[k] != 0]
return {name: metrics.SmuMetrics.AvgVoltage[k] for k, name in voltage_keys}
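The rename above swaps the ctypes-autogen enum dicts (c__EA_TEMP_e__enumvalues) for friendlier aliases (TEMP_e); the filtering pattern itself is unchanged. A hedged sketch with made-up sensor names and readings:
TEMP_e = {0: "TEMP_EDGE", 1: "TEMP_HOTSPOT", 2: "TEMP_MEM"}  # hypothetical enum table
AvgTemperature = [0, 87, 64]                                 # hypothetical SMU readings
temps = {name: AvgTemperature[k] for k, name in TEMP_e.items() if AvgTemperature[k] != 0}
assert temps == {"TEMP_HOTSPOT": 87, "TEMP_MEM": 64}         # zero readings are dropped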

View File

@@ -5,29 +5,7 @@ from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEven
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
from tinygrad.runtime.autogen import llvm, rocprof
from tinygrad.runtime.support.elf import elf_loader
# to pass NULL to callbacks
llvm.LLVMCreateDisasmCPUFeatures.argtypes = tuple(llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5]) + (ctypes.c_void_p, ctypes.c_void_p)
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
llvm.LLVMInitializeAMDGPUTargetInfo()
llvm.LLVMInitializeAMDGPUTargetMC()
llvm.LLVMInitializeAMDGPUAsmParser()
llvm.LLVMInitializeAMDGPUDisassembler()
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, None, None)
image, sections, relocs = elf_loader(lib)
text = next((sh.header for sh in sections if sh.name == ".text"), None)
off, sz = unwrap(text).sh_addr, unwrap(text).sh_size
addr_table:dict[int, tuple[str, int]] = {}
out = ctypes.create_string_buffer(128)
cur_off = off
while cur_off < sz + off:
view = (ctypes.c_ubyte * ((sz + off) - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
cur_off += instr_sz
return addr_table
from tinygrad.viz.serve import llvm_disasm
@dataclasses.dataclass(frozen=True)
class InstExec:

opencode.json Normal file
View File

@@ -0,0 +1 @@
{"$schema": "https://opencode.ai/config.json", "formatter": false}

View File

@@ -60,7 +60,7 @@ def universal_test(a, b, dtype, op):
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
tensor_value = (op[0](ta, tb)).numpy()
numpy_value = op[1](ta.numpy(), tb.numpy())
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
if dtype in dtypes.floats:
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7))
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
@@ -77,8 +77,8 @@ def universal_test_unary(a, dtype, op):
numpy_value = op[1](ta.numpy())
if dtype in dtypes.fp8s:
# CUDA casts f32 inf to fp8 MAX; AMD casts it to nan (E4M3) / inf (E5M2)
if math.isinf(numpy_value): return
numpy_value = truncate[dtype](numpy_value)
if math.isinf(numpy_value.item()): return
numpy_value = truncate[dtype](numpy_value.item())
if dtype in dtypes.floats:
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
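For intuition on what truncate[dtype] models for fp8, here is a hedged sketch that rounds a Python float to the nearest representable E4M3 value by brute force (assuming OCP E4M3FN: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, NaN but no inf):
import math
def fp8e4m3_decode(bits: int) -> float:
    s, e, m = bits >> 7, (bits >> 3) & 0xF, bits & 0x7
    if e == 0xF and m == 0x7: return float("nan")            # E4M3FN has NaN, no inf
    mag = (m / 8) * 2**-6 if e == 0 else (1 + m / 8) * 2**(e - 7)
    return -mag if s else mag
finite = [v for v in map(fp8e4m3_decode, range(256)) if not math.isnan(v)]
def fp8e4m3_round(x: float) -> float: return min(finite, key=lambda v: abs(v - x))
assert fp8e4m3_round(3.1) == 3.0                             # nearest E4M3 neighbor of 3.1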

View File

@@ -0,0 +1,172 @@
import unittest
import textwrap
from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, track_rewrites
from tinygrad.renderer import ProgramSpec
from tinygrad.helpers import TracingKey
from tinygrad.engine.realize import ExecItem, CompiledRunner
# TODO: use the RDNA3 renderer when it's in master
template = """.text
.globl fn_name
.p2align 8
.type fn_name,@function
fn_name:
INSTRUCTION
.rodata
.p2align 6
.amdhsa_kernel fn_name
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_wavefront_size32 1
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: fn_name
.symbol: fn_name.kd
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 8
.max_flat_workgroup_size: 1024
.kernarg_segment_align: 8
.kernarg_segment_size: 8
.args:
- .address_space: global
.name: a
.offset: 0
.size: 8
.type_name: 'float*'
.value_kind: global_buffer
...
.end_amdgpu_metadata
"""
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
def run_asm(name:str, src:str) -> ProgramSpec:
prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK))
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
ei.run()
return prg
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
class TestCfg(unittest.TestCase):
def setUp(self):
arch = Device["AMD"].arch
if not any(arch.startswith(a) for a in {"gfx11", "gfx12"}):
self.skipTest(f"tests written for RDNA, got arch {arch}")
def test_simple(self):
run_asm("simple", """
entry:
s_branch bb1
bb1:
s_endpgm
""")
def test_diamond(self):
run_asm("diamond", """
entry:
s_cmp_eq_i32 s0, 0
s_cbranch_scc1 if
s_branch else
if:
s_nop 1
s_branch end
else:
s_nop 0
end:
s_endpgm
""")
def test_loop(self):
run_asm("simple_loop", """
entry:
s_mov_b32 s1, 4
loop:
s_add_u32 s1, s1, -1
s_cmp_eq_i32 s1, 0
s_cbranch_scc0 loop
s_endpgm
""")
def test_loop_branch(self):
run_asm("loop_if", """
entry:
s_mov_b32 s1, 4
loop:
s_add_u32 s1, s1, -1
s_cmp_eq_i32 s1, 2
s_cbranch_scc1 cond
s_branch cont
cond:
s_add_u32 s1, s1, -2
cont:
s_cmp_eq_i32 s1, 0
s_cbranch_scc0 loop
s_endpgm
""")
def test_loop_break(self):
run_asm("loop_break", """
entry:
s_mov_b32 s1, 8
loop:
s_add_u32 s1, s1, -1
s_cmp_eq_i32 s1, 5
s_cbranch_scc1 break
s_cmp_eq_i32 s1, 0
s_cbranch_scc0 loop
break:
s_endpgm
""")
def test_switch(self):
run_asm("switch_case", """
entry:
s_cmp_eq_i32 s0, 0
s_cbranch_scc1 case0
s_cmp_eq_i32 s0, 1
s_cbranch_scc1 case1
s_branch case2
case0:
s_nop 0
s_branch join
case1:
s_nop 1
s_branch join
case2:
s_nop 2
s_branch join
join:
s_endpgm
""")
def test_ping_pong(self):
run_asm("ping_pong", """
entry:
s_cmp_eq_i32 s0, 0
s_cbranch_scc1 ping
s_branch pong
ping:
s_cmp_eq_i32 s1, 0
s_cbranch_scc1 pong
s_branch end
pong:
s_cmp_eq_i32 s2, 0
s_cbranch_scc1 ping
end:
s_endpgm
""")
if __name__ == "__main__":
unittest.main()

View File

@@ -66,5 +66,21 @@ class TestDtypeTolist(unittest.TestCase):
# 57344
self.assertEqual(Tensor([-30000, 1.5, 3.1, 30000], device="PYTHON", dtype=dtypes.fp8e5m2).tolist(), [-28672.0, 1.5, 3.0, 28672.0])
class TestCanLosslessCast(unittest.TestCase):
def test_can_lossless_cast(self):
from tinygrad.dtype import can_lossless_cast
# signed -> unsigned is NOT lossless (negative values wrap)
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.uint64))
self.assertFalse(can_lossless_cast(dtypes.int32, dtypes.uint32))
# unsigned -> larger signed is lossless
self.assertTrue(can_lossless_cast(dtypes.uint8, dtypes.int16))
self.assertTrue(can_lossless_cast(dtypes.uint32, dtypes.int64))
# large ints don't fit in floats
self.assertFalse(can_lossless_cast(dtypes.int32, dtypes.float))
self.assertFalse(can_lossless_cast(dtypes.int64, dtypes.double))
# half has more mantissa bits than bfloat16, so int8 fits losslessly only in half
self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half))
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16))
if __name__ == "__main__":
unittest.main()

test/unit/test_llm_moe.py Normal file
View File

@@ -0,0 +1,53 @@
import unittest
import numpy as np
from tinygrad import Tensor
class TestMoEFeedForward(unittest.TestCase):
def test_moe_feed_forward(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
# set up weights: gate scales by (expert_id+1), up/down are identity-ish, router picks experts 0,2
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T # router strongly prefers experts 0 and 2
block.ffn_norm.weight = Tensor.ones(dim) # identity norm
# input of ones -> after norm still ~ones -> experts 0,2 selected -> weighted sum of silu outputs
h = Tensor.ones(1, 1, dim)
out = block._feed_forward(h)
# expected: residual + moe_output ≈ 1 + avg(silu(1), silu(3))
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
def test_moe_feed_forward_batched(self):
from tinygrad.apps.llm import TransformerBlock
dim, hidden, n_heads = 8, 16, 2
num_experts, k = 4, 2
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
# same setup as BS=1 test
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T
block.ffn_norm.weight = Tensor.ones(dim)
# test with BS=2, T=3
h = Tensor.ones(2, 3, dim)
out = block._feed_forward(h)
# all outputs should match the BS=1 expected value
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
if __name__ == '__main__':
unittest.main()

View File

@@ -95,21 +95,24 @@ class TestUOpResolve(unittest.TestCase):
x = UOp.variable("i", 1, 10)
self.assertFalse(x < x)
@unittest.expectedFailure
def test_x_lt_xp1(self):
x = UOp.variable("i", 1, 10)
self.assertTrue(x < (x+1))
u = x < (x+1)
# TODO: improve
with self.assertRaises(ValueError):
bool(u)
def test_and_true(self):
u = UOp.variable("b", False, True, dtypes.bool) & True
with self.assertRaises(ValueError):
u = UOp.variable("b", False, True, dtypes.bool) & True
self.assertFalse(u)
bool(u)
@unittest.expectedFailure
def test_var_cmp_range(self):
v = UOp.variable("i", 1, 10)
u = (v > 4) | (v < 6)
self.assertTrue(u)
# TODO: improve
with self.assertRaises(ValueError):
bool(u)
def test_var_cmp_assert(self):
with self.assertRaises(ValueError):

View File

@@ -1,35 +0,0 @@
const { spawn } = require("child_process");
const puppeteer = require("puppeteer");
async function main() {
// ** start viz server
const proc = spawn("python", ["-u", "-c", "from tinygrad import Tensor; Tensor.arange(4).realize()"], { env: { ...process.env, VIZ:"1" },
stdio: ["inherit", "pipe", "inherit"]});
await new Promise(resolve => proc.stdout.on("data", r => {
if (r.includes("ready")) resolve();
}));
// ** run browser tests
let browser, page;
try {
browser = await puppeteer.launch({ headless: true });
page = await browser.newPage();
const res = await page.goto("http://localhost:8000", { waitUntil:"domcontentloaded" });
if (res.status() !== 200) throw new Error("Failed to load page");
const scheduleSelector = await page.waitForSelector("ul:nth-of-type(2)");
scheduleSelector.click();
await page.waitForSelector("rect");
await page.waitForFunction(() => {
const nodes = document.querySelectorAll("#nodes > g").length;
const edges = document.querySelectorAll("#edges > path").length;
return nodes > 0 && edges > 0;
});
} finally {
// ** cleanups
if (page != null) await page.close();
if (browser != null) await browser.close();
proc.kill();
}
}
main();

View File

@@ -5,7 +5,7 @@ from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler,
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
@@ -52,9 +52,13 @@ class SimpleTokenizer:
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
def role(self, role:str):
if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
def end_turn(self, eos_id:int): return [eos_id] + self.encode("\n") if self.preset == 'qwen2' else [eos_id]
def end_turn(self, eos_id:int):
if self.preset == 'olmo': return self.encode("\n")
if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
return [eos_id]
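A hedged usage sketch of the helpers above, framing one user turn per preset (tok is an already-constructed SimpleTokenizer; eos_id is model-specific):
def user_turn(tok, text: str, eos_id: int) -> list[int]:
    return tok.role("user") + tok.encode(text) + tok.end_turn(eos_id)
# olmo:   <|user|>\n ... \n                                  (newline-terminated, no eos)
# qwen2:  <|im_start|>user\n ... eos \n
# llama3: <|start_header_id|>user<|end_header_id|>\n\n ... eos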
@functools.cache
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
@@ -62,6 +66,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()
class ExpertWeights:
"""Like nn.Linear but with num_experts dimension. Weight shape: (num_experts, out_features, in_features)."""
def __init__(self, num_experts:int, in_features:int, out_features:int):
self.weight = Tensor.zeros(num_experts, out_features, in_features)
def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
# sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
return (x.unsqueeze(-2) @ self.weight[sel].transpose(-1, -2)).squeeze(-2)
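The gather-matmul in ExpertWeights.__call__ is easiest to see with shapes; a hedged numpy restatement (not the tinygrad implementation):
import numpy as np
B, T, k, E, din, dout = 2, 3, 2, 4, 8, 16
W = np.random.randn(E, dout, din)                  # per-expert weight matrices
sel = np.random.randint(0, E, (B, T, k))           # k selected experts per token
x = np.random.randn(B, T, 1, din)                  # shared input per token
out = (x[..., None, :] @ W[sel].swapaxes(-1, -2)).squeeze(-2)
assert out.shape == (B, T, k, dout)                # one output per selected expert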
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
assert x.shape[-1] % 2 == 0
cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1)
@@ -70,12 +82,13 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:bool=False):
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.max_context = max_context
self.rope_theta = rope_theta
self.max_context = max_context
self.qk_norm = qk_norm
# --- attention projections (all linear, bias-free) ------------------
q_proj_out = self.head_dim * n_heads
@@ -88,23 +101,30 @@ class TransformerBlock:
# --- RMSNorms --------------------------------------------------------
self.attn_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)
# --- feed-forward ----------------------------------------------------
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
# --- feed-forward (MoE or dense) -------------------------------------
if num_experts > 0:
self.num_experts_per_tok = num_experts_per_tok
self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim)
self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim)
else:
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
x_norm = self.attn_norm(x) # (B,T,D)
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
B, T, _ = x.shape
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if hasattr(self, 'attn_q_norm'): q, k = self.attn_q_norm(q), self.attn_k_norm(k)
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
# TODO: make UOp have SupportsIndex
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
@@ -127,6 +147,11 @@ class TransformerBlock:
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
if hasattr(self, 'ffn_gate_exps'):
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
# TODO: remove the need for this contiguous
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
return h + self.ffn_down(gated)
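For reference, the MoE branch above in plain numpy (a hedged sketch: the RMSNorm on h is omitted, and, like the code above, the top-k probabilities are not renormalized):
import numpy as np
def silu(v): return v / (1 + np.exp(-v))
def moe_ffn(h, Wr, Wg, Wu, Wd, k):                 # h:(B,T,D) Wr:(E,D) Wg,Wu:(E,H,D) Wd:(E,D,H)
    logits = h @ Wr.T                              # (B,T,E) router scores
    p = np.exp(logits - logits.max(-1, keepdims=True)); p /= p.sum(-1, keepdims=True)
    sel = np.argsort(-p, axis=-1)[..., :k]         # (B,T,k) top-k expert ids
    probs = np.take_along_axis(p, sel, axis=-1)    # (B,T,k) router weights
    x = h[..., None, None, :]                      # (B,T,1,1,D)
    gate = (x @ Wg[sel].swapaxes(-1, -2)).squeeze(-2)                # (B,T,k,H)
    up = (x @ Wu[sel].swapaxes(-1, -2)).squeeze(-2)                  # (B,T,k,H)
    down = ((silu(gate) * up)[..., None, :] @ Wd[sel].swapaxes(-1, -2)).squeeze(-2)
    return h + (down * probs[..., None]).sum(axis=2)                 # residual + weighted mix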
@@ -136,9 +161,9 @@ class TransformerBlock:
class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:bool=False):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm)
for _ in range(num_blocks)]
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
num_experts, num_experts_per_tok) for _ in range(num_blocks)]
self.token_embd = nn.Embedding(vocab_size, dim)
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -170,16 +195,20 @@ class Transformer:
max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
# permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
if arch != 'qwen3':
# Permute Q/K weights from interleaved to half-split RoPE layout (llama-style models only)
if arch == 'llama':
for name in state_dict:
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict)
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
@@ -207,6 +236,8 @@ models = {
"qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
"qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
"qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
"qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
"olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
}
# *** simple OpenAI compatible server on 11434 to match ollama ***

View File

@@ -219,17 +219,19 @@ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType)
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
@functools.cache
def can_safe_cast(dt0:DType, dt1:DType) -> bool:
def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
# return whether dt1 preserves all values of dt0
# https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
# similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
if dt0 == dt1 or dt0 == dtypes.bool: return True
match dt1:
case dtypes.index: return dt0 in dtypes.ints
case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, *dtypes.fp8s,
dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
case dtypes.half: return dt0 in (*dtypes.fp8s, dtypes.uint8, dtypes.int8)
case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
case dtypes.uint16: return dt0 in (dtypes.uint8,)
case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8)
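The int-to-float rules above follow from significand width: float32 carries 24 significand bits, so int32 (and anything wider) cannot round-trip. A quick check:
import struct
def roundtrip_f32(i: int) -> int:
    return int(struct.unpack("f", struct.pack("f", float(i)))[0])
assert roundtrip_f32(2**24) == 2**24
assert roundtrip_f32(2**24 + 1) == 2**24           # first integer float32 cannot hold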
@@ -320,11 +322,8 @@ def fp8_to_float(x: int, dtype: DType) -> float:
truncate: dict[DType, Callable] = {dtypes.bool: bool,
dtypes.float16: float_to_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
**{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
dtypes.int64: lambda x: ctypes.c_int64(x).value}
**{getattr(dtypes, n): (lambda x, c=getattr(ctypes, f'c_{n}'): c(x).value)
for n in ('float', 'double', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64')}}
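The c=getattr(...) default argument in the rewritten truncate table is the standard fix for Python's late-binding closures; without it, every lambda in the comprehension would see the final loop value:
bad = [lambda x: x * i for i in range(3)]          # all three capture i == 2
good = [lambda x, i=i: x * i for i in range(3)]    # default binds per iteration
assert [f(1) for f in bad] == [2, 2, 2]
assert [f(1) for f in good] == [0, 1, 2]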
# numpy and torch dtype interop

View File

@@ -20,8 +20,9 @@ from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_so
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta, PCIDevice, USBPCIDevice, MAP_FIXED, MAP_NORESERVE
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
SQTT, SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT", VIZ.value>=2), ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
PMC = ContextVar("PMC", VIZ.value>=2)
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
WAIT_REG_MEM_FUNCTION_NEQ = 4 # !=
@@ -829,8 +830,8 @@ class PCIIface(PCIIfaceBase):
doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0), pipe=0, queue=0)
else:
pv = self.dev_impl.gfx.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr, wptr_addr=gart.va_addr+wptr,
eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_MEC_RING0), pipe=0, queue=0,
aql=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL))
eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_MEC_RING0), pipe=0,
queue=int(is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL)), aql=is_aql)
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],
read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv)

View File

@@ -142,7 +142,7 @@ class AMMemoryManager(MemoryManager):
self.dev.gmc.flush_tlb(ip='MM', vmid=0)
class AMDev(PCIDevImplBase):
Version = 0xA0000006
Version = 0xA0000007
def __init__(self, pci_dev:PCIDevice, dma_regions:list[tuple[int, MMIOInterface]]|None=None, reset_mode=False):
self.pci_dev, self.devfmt, self.dma_regions = pci_dev, pci_dev.pcibus, dma_regions

View File

@@ -212,14 +212,19 @@ class AM_SMU(AM_IP):
return (self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).read() if read_back_arg else None
class AM_GFX(AM_IP):
def init_sw(self): self.xccs = len(self.adev.regs_offset[am.GC_HWIP])
def init_sw(self):
self.xccs = len(self.adev.regs_offset[am.GC_HWIP])
self.mqd_paddr = [self.adev.mm.palloc(0x1000 * self.xccs, zero=False, boot=True) for i in range(2)]
self.mqd_mc = [self.adev.paddr2mc(mqd_paddr) for mqd_paddr in self.mqd_paddr]
def init_hw(self):
# Wait for RLC autoload to complete
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
self._config_gfx_rs64()
self.adev.gmc.init_hub("GC", inst_cnt=self.xccs)
if self.adev.partial_boot: return
self._config_gfx_rs64()
# NOTE: Golden reg for gfx11. No values are provided for this reg; the kernel just ORs in 0x20000000.
for xcc in range(self.xccs): self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000, inst=xcc)
@@ -265,49 +270,61 @@ class AM_GFX(AM_IP):
if self.xccs > 1 and not self.adev.partial_boot: self.adev.psp._spatial_partition_cmd(1)
def fini_hw(self):
for xcc in range(self.xccs):
self._grbm_select(me=1, pipe=0, queue=0, inst=xcc)
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1: self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
self._grbm_select(inst=xcc)
# NOTE: For AQL queues with XCCs (queue=1), execution will continue from the saved state.
for q in range(2 if self.xccs == 1 else 1):
for xcc in range(self.xccs):
self._grbm_select(me=1, pipe=0, queue=q, inst=xcc)
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1: self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
self._grbm_select(inst=xcc)
for xcc in range(self.xccs): self.adev.regGCVM_CONTEXT0_CNTL.write(0, inst=xcc)
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int,
aql:bool) -> int:
for xcc in range(self.xccs if aql else 1):
mqd = self.adev.mm.valloc(0x1000, uncached=True, contiguous=True)
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=0)
restore_queue = aql and self.xccs > 1 and self.adev.partial_boot and (self.adev.regCP_HQD_ACTIVE.read(inst=0) & 1)
restore_ptr = (self.adev.regCP_HQD_PQ_WPTR_LO.read(inst=0) | (self.adev.regCP_HQD_PQ_WPTR_HI.read(inst=0) << 32)) if restore_queue else 0
if DEBUG >= 2 and restore_queue: print(f"am {self.adev.devfmt}: GFX queue already active, continuing from saved state {restore_ptr=:#x}.")
for xcc in range(self.xccs if aql else 1):
struct_t = getattr(am, f"struct_v{self.adev.ip_ver[am.GC_HWIP][0]}{'_compute' if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else ''}_mqd")
mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(self.mqd_mc[queue] + 0x1000*xcc),
cp_mqd_base_addr_hi=hi32(self.mqd_mc[queue] + 0x1000*xcc), cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2,
**({'queue_full_en':1, 'slot_based_wptr':2, 'wpp_clamp_en':1, 'no_update_rptr':xcc!=0 or self.xccs==1} if aql else {})),
**({'queue_full_en':1, 'slot_based_wptr':2, 'no_update_rptr':xcc!=0 or self.xccs==1} if aql else {})),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0, cp_hqd_aql_control=int(aql),
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2),
**({'compute_tg_chunk_size':1, 'compute_current_logic_xcc_id':xcc} if aql and self.xccs > 1 else {}))
**({'compute_tg_chunk_size':1, 'compute_current_logic_xcc_id':xcc, 'cp_mqd_stride_size':0x1000} if aql and self.xccs > 1 else {}))
for se in range(8 if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else 4): setattr(mqd_struct, f'compute_static_thread_mgmt_se{se}', 0xffffffff)
# Copy mqd into memory
self.adev.vram.view(mqd.paddrs[0][0], ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
self.adev.gmc.flush_hdp()
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=xcc)
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)):
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc)
if restore_queue:
for r in [self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR_HI,
self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR_HI, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR_HI]:
val = memoryview(bytes(mqd_struct)).cast('I')[0x80 + (off:=r.addr[xcc] - self.adev.regCP_MQD_BASE_ADDR.addr[xcc])]
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct), fmt='I')[0x80 + off] = val
r.write(val, inst=xcc)
else:
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)):
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc)
self.adev.gmc.flush_hdp()
self._grbm_select(inst=xcc)
self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1, inst=xcc)
return 0
return restore_ptr // 16
def set_clockgating_state(self):
if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
@@ -476,7 +493,7 @@ class AM_PSP(AM_IP):
if not self.is_sos_alive():
for fw, compid in sos_components: self._bootloader_load_component(fw, compid)
while not self.is_sos_alive(): time.sleep(0.01)
wait_cond(self.is_sos_alive, value=True, msg="sOS failed to start")
self._ring_create()
if am.PSP_FW_TYPE_PSP_TOC in self.adev.fw.sos_fw: self._tmr_init()

View File

@@ -1174,7 +1174,7 @@ if TRACK_MATCH_STATS or PROFILE:
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):

View File

@@ -2,7 +2,7 @@
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
@@ -68,6 +68,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
(UPat.var("x") ^ UPat.var("x"), lambda x: x.const_like(0)), # x^x -> 0
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding **
@@ -97,7 +98,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
# b.cast(a).cast(b) -> b if a preserves all values in b
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
# ** pow **
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
# positive const ** x
@@ -242,7 +243,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
# cast/long folding
# if the intermediate cast doesnt narrow we can do it in one cast
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_lossless_cast(x.dtype, a.dtype) else None),
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
# try to do math in int instead of long
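A concrete instance of the double-cast rule above (a hedged sketch with plain ints, not UOps): uint8 -> int32 -> int64 folds to uint8 -> int64 because int32 holds every uint8 value, while int64 -> int32 -> int64 must keep the narrowing cast:
def cast_i32(v: int) -> int:                       # wrap to int32, like a hardware cast
    return (v + 2**31) % 2**32 - 2**31
assert cast_i32(255) == 255                        # uint8 values survive the int32 hop
assert cast_i32(2**40) != 2**40                    # int64 values need not: no fold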

View File

@@ -124,7 +124,6 @@
fill: rgba(26, 27, 38, 0.5);
}
.edgePath {
stroke: #4a4b57;
fill: none;
stroke-width: 1.4px;
}
@@ -322,14 +321,6 @@
td.Instruction {
font-family: monospace;
}
td.pct-row > div {
height: 12px;
width: 100%;
display: flex;
}
td.pct-row > div > div {
height: 100%;
}
thead {
position: sticky;
top: 0;
@@ -350,12 +341,11 @@
tr.nested-row table tr.main-row:hover {
background-color: unset;
}
tr.main-row.has-children > td:first-child {
white-space: pre;
tr.main-row.has-children > td:first-child > p {
display: inline-block;
}
tr.main-row.has-children > td:first-child::before {
content: "▸ ";
display: inline-block;
width: 1em;
margin-left: -0.25em;
}

View File

@@ -26,18 +26,17 @@ const colored = n => d3.create("span").call(s => s.selectAll("span").data(typeof
const rect = (s) => (typeof s === "string" ? document.querySelector(s) : s).getBoundingClientRect();
let timeout = null;
const updateProgress = ({ start, err }) => {
const Status = {STARTED:0, COMPLETE:1, ERR:2}
const updateProgress = (st, msg) => {
clearTimeout(timeout);
const msg = document.getElementById("progress-message");
msg.style.display = "none";
if (start) {
msg.innerText = "Rendering new graph...";
timeout = setTimeout(() => { msg.style.display = "block"; }, 2000);
}
d3.select("#custom").html("");
if (err) {
const msgEl = d3.select("#progress-message").style("display", "none");
const customEl = d3.select("#custom").html("");
if (st === Status.STARTED) {
msgEl.text(msg);
timeout = setTimeout(() => msgEl.style("display", "block"), 2000);
} else if (st === Status.ERR) {
displaySelection("#custom");
d3.select("#custom").append("div").classed("raw-text", true).call(s => s.append(() => codeBlock(err, "txt"))).node();
customEl.append("div").classed("raw-text", true).append(() => codeBlock(msg));
}
}
@@ -77,23 +76,19 @@ const drawGraph = (data) => {
.attr("x", d => -d.width/2).attr("y", d => -d.height/2);
const STROKE_WIDTH = 1.4;
const labels = nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label");
const hasLabelDims = data.nodes[0]?.value.labelWidth != null;
if (hasLabelDims) labels.attr("transform", d => `translate(-${d.labelWidth/2}, -${d.labelHeight/2+STROKE_WIDTH*2})`);
labels.attr("transform", d => `translate(-${d.labelWidth/2}, -${d.labelHeight/2+STROKE_WIDTH*2})`);
labels.selectAll("text").data(d => {
if (Array.isArray(d.label)) return [d.label];
const ret = [[]];
for (const { st, color } of parseColors(d.label, defaultColor="initial")) {
const lines = st.split("\n");
for (const s of parseColors(d.label, defaultColor="initial")) {
const color = darkenHex(s.color, 25);
const lines = s.st.split("\n");
ret.at(-1).push({ st:lines[0], color });
for (let i=1; i<lines.length; i++) ret.push([{ st:lines[i], color }]);
}
return [ret];
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
.attr("fill", d => darkenHex(d.color, 25)).text(d => d.st).attr("xml:space", "preserve");
// recenter after drawing texts if needed
if (!hasLabelDims) labels.attr("transform", (_,i,els) => {
const b = els[i].getBBox();
return `translate(${-b.x-b.width/2}, ${-b.y-b.height/2})`
});
.attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve").style("font-family", g.graph().font);
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
// draw edges
@@ -104,7 +99,7 @@ const drawGraph = (data) => {
points.unshift(intersectRect(g.node(e.v), points[0]));
points.push(intersectRect(g.node(e.w), points[points.length-1]));
return line(points);
}).attr("marker-end", "url(#arrowhead)");
}).attr("marker-end", "url(#arrowhead)").attr("stroke", e => g.edge(e).color || "#4a4b57");
}
// ** UOp graph
@@ -115,15 +110,15 @@ async function initWorker() {
workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" }));
}
function renderDag(graph, additions, recenter, layoutOpts) {
function renderDag(layoutSpec, { recenter }) {
// start calculating the new layout (non-blocking)
updateProgress({ start:true });
updateProgress(Status.STARTED, "Rendering new graph...");
if (worker != null) worker.terminate();
worker = new Worker(workerUrl);
worker.postMessage({graph, additions, opts:layoutOpts });
worker.postMessage(layoutSpec);
worker.onmessage = (e) => {
displaySelection("#graph");
updateProgress({ start:false });
updateProgress(Status.COMPLETE);
drawGraph(e.data);
addTags(d3.select("#edge-labels").selectAll("g").data(e.data.edges).join("g").attr("transform", (e) => {
// get a point near the end
@@ -144,7 +139,7 @@ function renderDag(graph, additions, recenter, layoutOpts) {
};
worker.onerror = (e) => {
e.preventDefault();
updateProgress({ err:"Error in graph layout:\n"+e.message });
updateProgress(Status.ERR, "Error in graph layout:\n"+e.message);
}
}
@@ -159,8 +154,7 @@ const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SE:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
CATEGORICAL:["#ff8080", "#F4A261", "#C8F9D4", "#8D99AE", "#F4A261", "#ffffa2", "#ffffc0", "#87CEEB"],}
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SE:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),}
const cycleColors = (lst, i) => lst[i%lst.length];
const rescaleTrack = (source, tid, k) => {
@@ -257,7 +251,7 @@ async function renderProfiler(path, unit, opts) {
if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; }
metadata.replaceChildren(getMetadata(focusedShape));
// layout once!
if (data.tracks.size !== 0) return updateProgress({ start:false });
if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE);
const profiler = d3.select("#profiler").html("");
const buf = cache[path] ?? await fetchValue(path);
const view = new DataView(buf);
@@ -419,7 +413,7 @@ async function renderProfiler(path, unit, opts) {
}
}
for (const m of markers) m.label = m.name.split(/(\s+)/).map(st => ({ st, color:m.color, width:ctx.measureText(st).width }));
updateProgress({ start:false });
updateProgress(Status.COMPLETE);
// draw events on a timeline
const dpr = window.devicePixelRatio || 1;
const ellipsisWidth = ctx.measureText("...").width;
@@ -604,7 +598,7 @@ const pathLink = (fp, lineno) => d3.create("a").attr("href", "vscode://file/"+fp
function codeBlock(st, language, { loc, wrap }={}) {
const code = document.createElement("code");
// plaintext renders like a terminal print, otherwise render with syntax highlighting
if (language === "txt") code.appendChild(colored(st));
if (!language || language === "txt") code.appendChild(colored(st));
else code.innerHTML = hljs.highlight(st, { language }).value;
code.className = "hljs";
const ret = document.createElement("pre");
@@ -636,9 +630,6 @@ hljs.registerLanguage("cpp", (hljs) => ({
...hljs.getLanguage('cpp'),
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
hljs.registerLanguage("amdgpu", (hljs) => ({
contains: [hljs.COMMENT("//", "$"), { begin:/\b(?:s_|v_|global_|buffer_|scratch_|flat_|ds_)[a-z0-9_]*\b/, className:"code" }]
}));
async function fetchValue(path) {
const res = await fetch(path);
@@ -800,25 +791,19 @@ async function main() {
}
const td = tr.append("td").classed(ret.cols[i], true);
// string format scalar values
if (!Array.isArray(value)) { td.text(typeof value === "string" ? value : ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)); continue; }
// display arrays in a bar graph
td.classed("pct-row", true);
const bar = td.append("div");
value.forEach(([k, v, width]) => bar.append("div").style("width", width+"%").attr("title", `${ret.cols[i].labels[k]} ${v}`)
.style("background", cycleColors(colorScheme.CATEGORICAL, parseInt(k))))
td.append(() => typeof value === "string" ? colored(value) : d3.create("p").text(ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)).node());
}
}
return table;
}
if (ret.cols != null) {
renderTable(root, ret);
} else root.append(() => codeBlock(ret.src, ret.lang || "txt"));
if (ret.cols != null) renderTable(root, ret);
else if (ret.data != null) renderDag(ret, { recenter:true });
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
ret.metadata?.forEach(m => {
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value, idx }) => {
const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, idx)).style("width", "100%").style("height", "100%");
return [label.trim(), div.text(typeof value === "string" ? value : formatUnit(value)).node()];
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
})).node());
metadata.appendChild(codeBlock(m.src, "txt")).classList.add("full-height")
metadata.appendChild(codeBlock(m.src)).classList.add("full-height")
});
return document.querySelector("#custom").replaceChildren(root.node());
}
@@ -843,7 +828,7 @@ async function main() {
if (ret.length === 0) return;
// ** center graph
const data = ret[currentRewrite];
const render = (opts) => renderDag(data.graph, data.changed_nodes ?? [], currentRewrite === 0, opts);
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
render({ showIndexing:toggle.checked });
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
// ** right sidebar metadata

View File

@@ -1,27 +1,57 @@
const NODE_PADDING = 10;
const rectDims = (lw, lh) => ({ width:lw+NODE_PADDING*2, height:lh+NODE_PADDING*2, labelWidth:lw, labelHeight:lh });
const LINE_HEIGHT = 14;
const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext("2d");
ctx.font = `350 ${LINE_HEIGHT}px sans-serif`;
onmessage = (e) => {
const { graph, additions, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
const { data, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
postMessage(dagre.graphlib.json.write(g));
self.close();
}
const layoutCfg = (g, { blocks, paths, pc_table, colors }) => {
g.setGraph({ rankdir:"TD", font:"monospace" });
ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
// basic blocks render the assembly in nodes
for (const [lead, members] of Object.entries(blocks)) {
let [width, height, label] = [0, 0, []];
for (const m of members) {
const text = pc_table[m][0];
width = Math.max(width, ctx.measureText(text).width);
height += LINE_HEIGHT;
const [inst, ...operands] = text.split(" ");
label.push([{st:inst+" ", color:"#7aa2f7"}, {st:operands.join(" "), color:"#9aa5ce"}]);
}
g.setNode(lead, { ...rectDims(width, height), label, id:lead, color:"#1a1b26" });
}
// paths become edges between basic blocks
for (const [lead, value] of Object.entries(paths)) {
for (const [id, color] of Object.entries(value)) g.setEdge(lead, id, {label:{type:"port", text:""}, color:colors[color]});
}
dagre.layout(g);
}
const layoutUOp = (g, { graph, change }, opts) => {
g.setGraph({ rankdir: "LR", font:"sans-serif" });
ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
for (const [k, {label, src, ref, color, tag }] of Object.entries(graph)) {
// adjust node dims by label size (excluding escape codes) + add padding
let [width, height] = [0, 0];
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
width = Math.max(width, ctx.measureText(line).width);
height += LINE_HEIGHT;
}
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest});
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag});
// add edges
const edgeCounts = {}
const edgeCounts = {};
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
}
// optionally hide nodes from the layout
if (!opts.showIndexing) {
@@ -31,8 +61,6 @@ onmessage = (e) => {
}
}
dagre.layout(g);
// remove additions overlay if it's empty
if (!g.node("addition")?.width) g.removeNode("addition");
postMessage(dagre.graphlib.json.write(g));
self.close();
// remove overlay node if it's empty
if (!g.node("overlay")?.width) g.removeNode("overlay");
}

View File

@@ -6,7 +6,7 @@ from decimal import Decimal
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable, system, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
@@ -53,7 +53,7 @@ class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
uop: str # stringified UOp for this rewrite step
diff: list[str]|None # diff of the single UOp that changed
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
change: list[int]|None # the new UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
@@ -115,14 +115,14 @@ def _reconstruct(a:int):
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(ctx.sink)
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None}
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
if not ctx.bottom_up: next_sink = new_sink
@@ -248,23 +248,25 @@ def load_counters(profile:list[ProfileEvent]) -> None:
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), \
(durations, {k:v[ProfilePMCEvent][0] for k,v in counter_events.items()}))]})
if len(counter_events) == 0: return None
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
run_number = {n:0 for n,_ in counter_events}
for k,v in counter_events.items():
prg = trace.keys[r].ret if (r:=ref_map.get(k[0])) else None
name = prg.name if prg is not None else k[0]
run_number[k[0]] += 1
for (k, tag),v in counter_events.items():
# use the colored name if it exists
name = trace.keys[r].ret.name if (r:=ref_map.get(k)) is not None else k
run_number[k] += 1
steps:list[dict] = []
if (pmc:=v.get(ProfilePMCEvent)): steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
if (pmc:=v.get(ProfilePMCEvent)):
steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
all_counters[(name, run_number[k], k)] = pmc[0]
if (sqtt:=v.get(ProfileSQTTEvent)):
# to decode a SQTT trace, we need the raw stream, program binary and device properties
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), (k, [*sqtt, prg_events[k[0]], dev_events[sqtt[0].device]])))
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]])))
if getenv("SQTT_PARSE"):
# run our decoder on startup, we don't use this since it only works on gfx11
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
for e in sqtt: parse_sqtt_print_packets(e.blob)
ctxs.append({"name":f"Exec {name} n{run_number[k[0]]}", "steps":steps})
ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps})
# ** SQTT OCC only unpacks wave start, end time and SIMD location
@@ -337,26 +339,6 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
# ** Assembly static analyzers
def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
target_args = f"-mtriple={mtriple} -mcpu={mcpu}"
# disassembly output can include headers/metadata; tell llvm-mca to skip lines it can't parse
data = json.loads(system("llvm-mca -skip-unsupported-instructions=parse-failure --json -"+target_args, input=asm.encode()))
cr = data["CodeRegions"][0]
resource_labels = [repr(x)[1:-1] for x in data["TargetInfo"]["Resources"]]
rows:list = [[instr] for instr in cr["Instructions"]]
# add scheduler estimates
for info in cr["InstructionInfoView"]["InstructionList"]: rows[info["Instruction"]].append(info["Latency"])
# map per instruction resource usage
instr_usage:dict[int, dict[int, int]] = {}
for d in cr["ResourcePressureView"]["ResourcePressureInfo"]:
instr_usage.setdefault(i:=d["InstructionIndex"], {}).setdefault(r:=d["ResourceIndex"], 0)
instr_usage[i][r] += d["ResourceUsage"]
# last row is the usage summary
summary = [{"idx":k, "label":resource_labels[k], "value":v} for k,v in instr_usage.pop(len(rows), {}).items()]
max_usage = max([sum(v.values()) for i,v in instr_usage.items() if i<len(rows)], default=0)
for i,usage in instr_usage.items(): rows[i].append([[k, v, (v/max_usage)*100] for k,v in usage.items()])
return {"rows":rows, "cols":["Instruction", "Latency", {"title":"HW Resources", "labels":resource_labels}], "metadata":[summary]}
def get_stdout(f: Callable) -> str:
buf = io.StringIO()
try:
@@ -376,6 +358,67 @@ def amd_readelf(lib:bytes) -> list[dict]:
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
from tinygrad.runtime.autogen import llvm
from tinygrad.runtime.support.elf import elf_loader
llvm.LLVMInitializeAMDGPUTargetInfo()
llvm.LLVMInitializeAMDGPUTargetMC()
llvm.LLVMInitializeAMDGPUAsmParser()
llvm.LLVMInitializeAMDGPUDisassembler()
# pass NULL to callbacks
cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in {5,6}]
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
image, sections, _ = elf_loader(lib)
text = next((sh.header for sh in sections if sh.name == ".text"), None)
assert text is not None, "no .text section found in ELF"
off, sz = text.sh_addr, text.sh_size
addr_table:dict[int, tuple[str, int]] = {}
out = ctypes.create_string_buffer(128)
cur_off = off
while cur_off < sz + off:
view = (ctypes.c_ubyte * ((sz + off) - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
cur_off += instr_sz
return addr_table
SOPP_INSTS = {"s_branch", "s_cbranch_scc0", "s_cbranch_scc1", "s_cbranch_vccz", "s_cbranch_vccnz", "s_cbranch_execz", "s_cbranch_execnz"}
def parse_branch(asm:str) -> int|None:
inst, *operands = asm.split(" ")
if inst in SOPP_INSTS:
x = int(operands[0]) & 0xffff
return (x - 0x10000 if x & 0x8000 else x)*4
return None
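The decode in parse_branch treats the operand as a signed 16-bit count of 4-byte dwords, relative to the pc after the instruction; two worked values:
x = -2 & 0xffff                                        # 0xFFFE
assert (x - 0x10000 if x & 0x8000 else x) * 4 == -8    # two dwords backwards
assert (3 & 0xffff) * 4 == 12                          # three dwords forwards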
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
def amdgpu_cfg(lib:bytes, arch:str) -> dict:
# disassemble
pc_table = llvm_disasm(arch, lib)
# get leaders
leaders:set[int] = {next(iter(pc_table))}
for pc, (asm, sz) in pc_table.items():
if (offset:=parse_branch(asm)) is not None: leaders.update((pc+sz+offset, pc+sz))
# build the cfg
curr:int|None = None
blocks:dict[int, list[int]] = {}
paths:dict[int, dict[int, int]] = {}
for pc, (asm, sz) in pc_table.items():
if pc in leaders:
paths[curr:=pc] = {}
blocks[pc] = []
else: assert curr is not None, f"no basic block found for {pc}"
blocks[curr].append(pc)
# control flow ends in endpgm
if asm == "s_endpgm": break
# otherwise a basic block has exactly one or two outgoing paths
nx = pc+sz
if (offset:=parse_branch(asm)) is not None:
if asm.startswith("s_branch"): paths[curr][nx+offset] = UNCOND
else: paths[curr].update([(nx+offset, COND_TAKEN), (nx, COND_NOT_TAKEN)])
elif nx in leaders: paths[curr][nx] = UNCOND
return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
# ** Main render function to get the complete details about a trace event
def get_render(i:int, j:int, fmt:str) -> dict:
@@ -386,22 +429,19 @@ def get_render(i:int, j:int, fmt:str) -> dict:
if fmt == "asm":
compiler = Device[data.device].compiler
disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(data.src)))
from tinygrad.runtime.support.compiler_cpu import llvm, LLVMCompiler
if isinstance(compiler, LLVMCompiler):
return get_llvm_mca(disasm_str, ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode(),
ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode())
metadata:list = []
ret:dict = {"src":disasm_str}
if data.device.startswith("AMD"):
with soft_err(lambda err: metadata.append(err)):
metadata.append(amd_readelf(compiler.compile(data.src)))
return {"src":disasm_str, "lang":"amdgpu" if data.device.startswith("AMD") else None, "metadata":metadata}
with soft_err(lambda err: ret.update(err)):
metadata = amd_readelf(lib:=compiler.compile(data.src))
ret = {"data":amdgpu_cfg(lib, getattr(compiler, "arch")), "metadata":[metadata]}
return ret
if fmt == "all-pmc":
durations, pmc = data
ret:dict = {"cols":{}, "rows":[]}
for (prg,_),events in pmc.items():
ret = {"cols":{}, "rows":[]}
for (name, n, k),events in data[1].items():
pmc_table = unpack_pmc(events)
ret["cols"].update([(r[0], None) for r in pmc_table["rows"]])
ret["rows"].append((prg, durations[prg].pop(0), *[r[1] for r in pmc_table["rows"]]))
ret["rows"].append((name, durations[k][n-1], *[r[1] for r in pmc_table["rows"]]))
ret["cols"] = ["Kernel", "Duration", *ret["cols"]]
return ret
if fmt == "prg-pmc": return unpack_pmc(data[0])