mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Merge branch 'master' into typed_checks
This commit is contained in:
@@ -1,9 +0,0 @@
|
||||
import globals from "globals";
|
||||
import pluginJs from "@eslint/js";
|
||||
import pluginHtml from "eslint-plugin-html";
|
||||
|
||||
export default [
|
||||
{files: ["**/*.html"], plugins: {html: pluginHtml}, rules:{"max-len": ["error", {"code": 150}]}},
|
||||
{languageOptions: {globals: globals.browser}},
|
||||
pluginJs.configs.recommended,
|
||||
];
|
||||
@@ -1438,28 +1438,34 @@ def train_llama3():
|
||||
iter = get_train_iter()
|
||||
i, sequences_seen = resume_ckpt, 0
|
||||
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
||||
t = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
loss, lr = train_step(model, tokens)
|
||||
loss = loss.float().item()
|
||||
if getenv("TRAIN", 1):
|
||||
t = time.perf_counter()
|
||||
loss, lr = train_step(model, tokens)
|
||||
loss = loss.float().item()
|
||||
lr = lr.item()
|
||||
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
|
||||
tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
|
||||
if (fname:=getenv("LOSS_FILE", "")):
|
||||
with open(fname, "a") as f:
|
||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
sec = time.perf_counter()-t
|
||||
tqdm.write(
|
||||
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
|
||||
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
|
||||
|
||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||
tqdm.write("saving checkpoint")
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
if (fname:=getenv("LOSS_FILE", "")):
|
||||
with open(fname, "a") as f:
|
||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
|
||||
tqdm.write("saving optim checkpoint")
|
||||
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
||||
safe_save(get_state_dict(scheduler), fn)
|
||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||
tqdm.write("saving checkpoint")
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
|
||||
tqdm.write("saving optim checkpoint")
|
||||
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
||||
safe_save(get_state_dict(scheduler), fn)
|
||||
|
||||
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
|
||||
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
||||
|
||||
@@ -184,7 +184,7 @@ class SMICtx:
|
||||
if compact: return {k: temps[k] for k in ("Hotspot", "HBM") if temps.get(k, 0) != 0}
|
||||
return {k: v for k, v in temps.items() if v != 0}
|
||||
case _:
|
||||
temps_keys = [(k, name) for k, name in dev.smu.smu_mod.c__EA_TEMP_e__enumvalues.items()
|
||||
temps_keys = [(k, name) for k, name in dev.smu.smu_mod.TEMP_e.items()
|
||||
if k < dev.smu.smu_mod.TEMP_COUNT and metrics.SmuMetrics.AvgTemperature[k] != 0]
|
||||
if compact: temps_keys = [(k, name) for k, name in temps_keys if k in (dev.smu.smu_mod.TEMP_HOTSPOT, dev.smu.smu_mod.TEMP_MEM)]
|
||||
return {name: metrics.SmuMetrics.AvgTemperature[k] for k, name in temps_keys}
|
||||
@@ -193,7 +193,7 @@ class SMICtx:
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return {}
|
||||
case _:
|
||||
voltage_keys = [(k, name) for k, name in dev.smu.smu_mod.c__EA_SVI_PLANE_e__enumvalues.items()
|
||||
voltage_keys = [(k, name) for k, name in dev.smu.smu_mod.SVI_PLANE_e.items()
|
||||
if k < dev.smu.smu_mod.SVI_PLANE_COUNT and metrics.SmuMetrics.AvgVoltage[k] != 0]
|
||||
return {name: metrics.SmuMetrics.AvgVoltage[k] for k, name in voltage_keys}
|
||||
|
||||
|
||||
@@ -5,29 +5,7 @@ from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEven
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
from tinygrad.runtime.autogen import llvm, rocprof
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
# to pass NULL to callbacks
|
||||
llvm.LLVMCreateDisasmCPUFeatures.argtypes = tuple(llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5]) + (ctypes.c_void_p, ctypes.c_void_p)
|
||||
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
|
||||
llvm.LLVMInitializeAMDGPUTargetInfo()
|
||||
llvm.LLVMInitializeAMDGPUTargetMC()
|
||||
llvm.LLVMInitializeAMDGPUAsmParser()
|
||||
llvm.LLVMInitializeAMDGPUDisassembler()
|
||||
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, None, None)
|
||||
|
||||
image, sections, relocs = elf_loader(lib)
|
||||
text = next((sh.header for sh in sections if sh.name == ".text"), None)
|
||||
off, sz = unwrap(text).sh_addr, unwrap(text).sh_size
|
||||
|
||||
addr_table:dict[int, tuple[str, int]] = {}
|
||||
out = ctypes.create_string_buffer(128)
|
||||
cur_off = off
|
||||
while cur_off < sz + off:
|
||||
view = (ctypes.c_ubyte * ((sz + off) - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
|
||||
instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
|
||||
addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
|
||||
cur_off += instr_sz
|
||||
return addr_table
|
||||
from tinygrad.viz.serve import llvm_disasm
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class InstExec:
|
||||
|
||||
1
opencode.json
Normal file
1
opencode.json
Normal file
@@ -0,0 +1 @@
|
||||
{"$schema": "https://opencode.ai/config.json", "formatter": false}
|
||||
@@ -60,7 +60,7 @@ def universal_test(a, b, dtype, op):
|
||||
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
|
||||
tensor_value = (op[0](ta, tb)).numpy()
|
||||
numpy_value = op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
|
||||
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
|
||||
if dtype in dtypes.floats:
|
||||
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7))
|
||||
np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
|
||||
@@ -77,8 +77,8 @@ def universal_test_unary(a, dtype, op):
|
||||
numpy_value = op[1](ta.numpy())
|
||||
if dtype in dtypes.fp8s:
|
||||
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
|
||||
if math.isinf(numpy_value): return
|
||||
numpy_value = truncate[dtype](numpy_value)
|
||||
if math.isinf(numpy_value.item()): return
|
||||
numpy_value = truncate[dtype](numpy_value.item())
|
||||
if dtype in dtypes.floats:
|
||||
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
|
||||
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
|
||||
|
||||
172
test/testextra/test_cfg_viz.py
Normal file
172
test/testextra/test_cfg_viz.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import unittest
|
||||
import textwrap
|
||||
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.uop.ops import UOp, Ops, track_rewrites
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.helpers import TracingKey
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
|
||||
# TODO: use the RDNA3 renderer when it's in master
|
||||
template = """.text
|
||||
.globl fn_name
|
||||
.p2align 8
|
||||
.type fn_name,@function
|
||||
fn_name:
|
||||
INSTRUCTION
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel fn_name
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
|
||||
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
|
||||
.amdhsa_wavefront_size32 1
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: fn_name
|
||||
.symbol: fn_name.kd
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 8
|
||||
.max_flat_workgroup_size: 1024
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 8
|
||||
.args:
|
||||
- .address_space: global
|
||||
.name: a
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.type_name: 'float*'
|
||||
.value_kind: global_buffer
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
|
||||
def run_asm(name:str, src:str) -> ProgramSpec:
|
||||
prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK))
|
||||
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
|
||||
ei.run()
|
||||
return prg
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
|
||||
class TestCfg(unittest.TestCase):
|
||||
def setUp(self):
|
||||
arch = Device["AMD"].arch
|
||||
if not any(arch.startswith(a) for a in {"gfx11", "gfx12"}):
|
||||
self.skipTest(f"tests written for RDNA, got arch {arch}")
|
||||
|
||||
def test_simple(self):
|
||||
run_asm("simple", """
|
||||
entry:
|
||||
s_branch bb1
|
||||
bb1:
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
def test_diamond(self):
|
||||
run_asm("diamond", """
|
||||
entry:
|
||||
s_cmp_eq_i32 s0, 0
|
||||
s_cbranch_scc1 if
|
||||
s_branch else
|
||||
if:
|
||||
s_nop 1
|
||||
s_branch end
|
||||
else:
|
||||
s_nop 0
|
||||
end:
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
def test_loop(self):
|
||||
run_asm("simple_loop", """
|
||||
entry:
|
||||
s_mov_b32 s1, 4
|
||||
loop:
|
||||
s_add_u32 s1, s1, -1
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc0 loop
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
def test_loop_branch(self):
|
||||
run_asm("loop_if", """
|
||||
entry:
|
||||
s_mov_b32 s1, 4
|
||||
loop:
|
||||
s_add_u32 s1, s1, -1
|
||||
s_cmp_eq_i32 s1, 2
|
||||
s_cbranch_scc1 cond
|
||||
s_branch cont
|
||||
cond:
|
||||
s_add_u32 s1, s1, -2
|
||||
cont:
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc0 loop
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
def test_loop_break(self):
|
||||
run_asm("loop_break", """
|
||||
entry:
|
||||
s_mov_b32 s1, 8
|
||||
loop:
|
||||
s_add_u32 s1, s1, -1
|
||||
s_cmp_eq_i32 s1, 5
|
||||
s_cbranch_scc1 break
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc0 loop
|
||||
break:
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
def test_switch(self):
|
||||
run_asm("switch_case", """
|
||||
entry:
|
||||
s_cmp_eq_i32 s0, 0
|
||||
s_cbranch_scc1 case0
|
||||
s_cmp_eq_i32 s0, 1
|
||||
s_cbranch_scc1 case1
|
||||
s_branch case2
|
||||
case0:
|
||||
s_nop 0
|
||||
s_branch join
|
||||
case1:
|
||||
s_nop 1
|
||||
s_branch join
|
||||
case2:
|
||||
s_nop 2
|
||||
s_branch join
|
||||
join:
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
def test_ping_pong(self):
|
||||
run_asm("ping_pong", """
|
||||
entry:
|
||||
s_cmp_eq_i32 s0, 0
|
||||
s_cbranch_scc1 ping
|
||||
s_branch pong
|
||||
ping:
|
||||
s_cmp_eq_i32 s1, 0
|
||||
s_cbranch_scc1 pong
|
||||
s_branch end
|
||||
pong:
|
||||
s_cmp_eq_i32 s2, 0
|
||||
s_cbranch_scc1 ping
|
||||
end:
|
||||
s_endpgm
|
||||
""")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -66,5 +66,21 @@ class TestDtypeTolist(unittest.TestCase):
|
||||
# 57344
|
||||
self.assertEqual(Tensor([-30000, 1.5, 3.1, 30000], device="PYTHON", dtype=dtypes.fp8e5m2).tolist(), [-28672.0, 1.5, 3.0, 28672.0])
|
||||
|
||||
class TestCanLosslessCast(unittest.TestCase):
|
||||
def test_can_lossless_cast(self):
|
||||
from tinygrad.dtype import can_lossless_cast
|
||||
# signed -> unsigned is NOT lossless (negative values wrap)
|
||||
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.uint64))
|
||||
self.assertFalse(can_lossless_cast(dtypes.int32, dtypes.uint32))
|
||||
# unsigned -> larger signed is lossless
|
||||
self.assertTrue(can_lossless_cast(dtypes.uint8, dtypes.int16))
|
||||
self.assertTrue(can_lossless_cast(dtypes.uint32, dtypes.int64))
|
||||
# large ints don't fit in floats
|
||||
self.assertFalse(can_lossless_cast(dtypes.int32, dtypes.float))
|
||||
self.assertFalse(can_lossless_cast(dtypes.int64, dtypes.double))
|
||||
# half has more mantissa bits
|
||||
self.assertTrue(can_lossless_cast(dtypes.int8, dtypes.half))
|
||||
self.assertFalse(can_lossless_cast(dtypes.int8, dtypes.bfloat16))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
53
test/unit/test_llm_moe.py
Normal file
53
test/unit/test_llm_moe.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
|
||||
class TestMoEFeedForward(unittest.TestCase):
|
||||
def test_moe_feed_forward(self):
|
||||
from tinygrad.apps.llm import TransformerBlock
|
||||
dim, hidden, n_heads = 8, 16, 2
|
||||
num_experts, k = 4, 2
|
||||
|
||||
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
|
||||
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
|
||||
|
||||
# set up weights: gate scales by (expert_id+1), up/down are identity-ish, router picks experts 0,2
|
||||
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
|
||||
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
|
||||
block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
|
||||
block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T # router strongly prefers experts 0 and 2
|
||||
block.ffn_norm.weight = Tensor.ones(dim) # identity norm
|
||||
|
||||
# input of ones -> after norm still ~ones -> experts 0,2 selected -> weighted sum of silu outputs
|
||||
h = Tensor.ones(1, 1, dim)
|
||||
out = block._feed_forward(h)
|
||||
|
||||
# expected: residual + moe_output ≈ 1 + avg(silu(1), silu(3))
|
||||
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
|
||||
np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
|
||||
|
||||
def test_moe_feed_forward_batched(self):
|
||||
from tinygrad.apps.llm import TransformerBlock
|
||||
dim, hidden, n_heads = 8, 16, 2
|
||||
num_experts, k = 4, 2
|
||||
|
||||
block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
|
||||
rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k)
|
||||
|
||||
# same setup as BS=1 test
|
||||
block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
|
||||
block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
|
||||
block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
|
||||
block.ffn_gate_inp.weight = Tensor([[1, 0, 1, 0]] * dim).T
|
||||
block.ffn_norm.weight = Tensor.ones(dim)
|
||||
|
||||
# test with BS=2, T=3
|
||||
h = Tensor.ones(2, 3, dim)
|
||||
out = block._feed_forward(h)
|
||||
|
||||
# all outputs should match the BS=1 expected value
|
||||
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
|
||||
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -95,21 +95,24 @@ class TestUOpResolve(unittest.TestCase):
|
||||
x = UOp.variable("i", 1, 10)
|
||||
self.assertFalse(x < x)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_x_lt_xp1(self):
|
||||
x = UOp.variable("i", 1, 10)
|
||||
self.assertTrue(x < (x+1))
|
||||
u = x < (x+1)
|
||||
# TODO: improve
|
||||
with self.assertRaises(ValueError):
|
||||
bool(u)
|
||||
|
||||
def test_and_true(self):
|
||||
u = UOp.variable("b", False, True, dtypes.bool) & True
|
||||
with self.assertRaises(ValueError):
|
||||
u = UOp.variable("b", False, True, dtypes.bool) & True
|
||||
self.assertFalse(u)
|
||||
bool(u)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_var_cmp_range(self):
|
||||
v = UOp.variable("i", 1, 10)
|
||||
u = (v > 4) | (v < 6)
|
||||
self.assertTrue(u)
|
||||
# TODO: improve
|
||||
with self.assertRaises(ValueError):
|
||||
bool(u)
|
||||
|
||||
def test_var_cmp_assert(self):
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
const { spawn } = require("child_process");
|
||||
const puppeteer = require("puppeteer");
|
||||
|
||||
async function main() {
|
||||
// ** start viz server
|
||||
const proc = spawn("python", ["-u", "-c", "from tinygrad import Tensor; Tensor.arange(4).realize()"], { env: { ...process.env, VIZ:"1" },
|
||||
stdio: ["inherit", "pipe", "inherit"]});
|
||||
await new Promise(resolve => proc.stdout.on("data", r => {
|
||||
if (r.includes("ready")) resolve();
|
||||
}));
|
||||
|
||||
// ** run browser tests
|
||||
let browser, page;
|
||||
try {
|
||||
browser = await puppeteer.launch({ headless: true });
|
||||
page = await browser.newPage();
|
||||
const res = await page.goto("http://localhost:8000", { waitUntil:"domcontentloaded" });
|
||||
if (res.status() !== 200) throw new Error("Failed to load page");
|
||||
const scheduleSelector = await page.waitForSelector("ul:nth-of-type(2)");
|
||||
scheduleSelector.click();
|
||||
await page.waitForSelector("rect");
|
||||
await page.waitForFunction(() => {
|
||||
const nodes = document.querySelectorAll("#nodes > g").length;
|
||||
const edges = document.querySelectorAll("#edges > path").length;
|
||||
return nodes > 0 && edges > 0;
|
||||
});
|
||||
} finally {
|
||||
// ** cleanups
|
||||
if (page != null) await page.close();
|
||||
if (browser != null) await browser.close();
|
||||
proc.kill();
|
||||
}
|
||||
}
|
||||
|
||||
main();
|
||||
@@ -5,7 +5,7 @@ from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler,
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
|
||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
||||
if preset not in ("llama3","llama-v3","llama-bpe","qwen2","olmo"): raise ValueError(f"Invalid tokenizer preset '{preset}'")
|
||||
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
|
||||
bs = [*range(33, 127), *range(161, 173), *range(174, 256)] # bytes that map to themselves
|
||||
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
|
||||
@@ -52,9 +52,13 @@ class SimpleTokenizer:
|
||||
|
||||
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode(errors='replace')
|
||||
def role(self, role:str):
|
||||
if self.preset == 'olmo': return self.encode("<|" + role + "|>\n") # OLMoE Instruct format
|
||||
if self.preset == 'qwen2': return self.encode("<|im_start|>" + role + "\n")
|
||||
return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
|
||||
def end_turn(self, eos_id:int): return [eos_id] + self.encode("\n") if self.preset == 'qwen2' else [eos_id]
|
||||
def end_turn(self, eos_id:int):
|
||||
if self.preset == 'olmo': return self.encode("\n")
|
||||
if self.preset == 'qwen2': return [eos_id] + self.encode("\n")
|
||||
return [eos_id]
|
||||
|
||||
@functools.cache
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
||||
@@ -62,6 +66,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
return freqs.cos().cat(freqs.sin(), dim=-1).contiguous()
|
||||
|
||||
class ExpertWeights:
|
||||
"""Like nn.Linear but with num_experts dimension. Weight shape: (num_experts, out_features, in_features)."""
|
||||
def __init__(self, num_experts:int, in_features:int, out_features:int):
|
||||
self.weight = Tensor.zeros(num_experts, out_features, in_features)
|
||||
def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
|
||||
# sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
|
||||
return (x.unsqueeze(-2) @ self.weight[sel].transpose(-1, -2)).squeeze(-2)
|
||||
|
||||
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
|
||||
assert x.shape[-1] % 2 == 0
|
||||
cos, sin = freqs_cis.reshape(1, 1, x.shape[2], -1).chunk(2, dim=-1)
|
||||
@@ -70,12 +82,13 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
|
||||
max_context:int=0, qk_norm:bool=False):
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.max_context = max_context
|
||||
self.rope_theta = rope_theta
|
||||
self.max_context = max_context
|
||||
self.qk_norm = qk_norm
|
||||
|
||||
# --- attention projections (all linear, bias-free) ------------------
|
||||
q_proj_out = self.head_dim * n_heads
|
||||
@@ -88,23 +101,30 @@ class TransformerBlock:
|
||||
# --- RMSNorms --------------------------------------------------------
|
||||
self.attn_norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
|
||||
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(self.head_dim, norm_eps), nn.RMSNorm(self.head_dim, norm_eps)
|
||||
if qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(qk_norm, norm_eps), nn.RMSNorm(qk_norm, norm_eps)
|
||||
|
||||
# --- feed-forward ----------------------------------------------------
|
||||
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
|
||||
# --- feed-forward (MoE or dense) -------------------------------------
|
||||
if num_experts > 0:
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
|
||||
self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
|
||||
self.ffn_up_exps = ExpertWeights(num_experts, dim, hidden_dim)
|
||||
self.ffn_down_exps = ExpertWeights(num_experts, hidden_dim, dim)
|
||||
else:
|
||||
self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
x_norm = self.attn_norm(x) # (B,T,D)
|
||||
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
|
||||
if self.qk_norm and self.qk_norm != self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
||||
B, T, _ = x.shape
|
||||
q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2) # (B,H,T,Hd)
|
||||
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
|
||||
if hasattr(self, 'attn_q_norm'): q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
|
||||
|
||||
# TODO: make UOp have SupportsIndex
|
||||
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T] # type: ignore
|
||||
@@ -127,6 +147,11 @@ class TransformerBlock:
|
||||
|
||||
def _feed_forward(self, h: Tensor) -> Tensor:
|
||||
h_norm = self.ffn_norm(h)
|
||||
if hasattr(self, 'ffn_gate_exps'):
|
||||
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
|
||||
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
|
||||
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
|
||||
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
|
||||
# TODO: remove the need for this contiguous
|
||||
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
|
||||
return h + self.ffn_down(gated)
|
||||
@@ -136,9 +161,9 @@ class TransformerBlock:
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
|
||||
max_context:int=0, qk_norm:bool=False):
|
||||
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm)
|
||||
for _ in range(num_blocks)]
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
|
||||
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
|
||||
num_experts, num_experts_per_tok) for _ in range(num_blocks)]
|
||||
self.token_embd = nn.Embedding(vocab_size, dim)
|
||||
self.output_norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.output = nn.Linear(dim, vocab_size, bias=False)
|
||||
@@ -170,16 +195,20 @@ class Transformer:
|
||||
max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
|
||||
n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
|
||||
|
||||
# permute Q/K weights from interleaved to half-split RoPE layout: [0,1,2,3,4,5...] -> [0,2,4,...,1,3,5,...]
|
||||
if arch != 'qwen3':
|
||||
# Permute Q/K weights from interleaved to half-split RoPE layout (llama-style models only)
|
||||
if arch == 'llama':
|
||||
for name in state_dict:
|
||||
if 'attn_q.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_heads, two=2)
|
||||
if 'attn_k.weight' in name: state_dict[name] = state_dict[name].rearrange("(n h two) d -> (n two h) d", n=n_kv_heads, two=2)
|
||||
|
||||
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
|
||||
model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
|
||||
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
|
||||
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
|
||||
vocab_size=len(kv['tokenizer.ggml.tokens']), head_dim=kv[f'{arch}.attention.key_length'],
|
||||
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm='blk.0.attn_q_norm.weight' in state_dict)
|
||||
vocab_size=len(kv['tokenizer.ggml.tokens']),
|
||||
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
|
||||
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
|
||||
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
|
||||
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
|
||||
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
||||
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
||||
for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
|
||||
@@ -207,6 +236,8 @@ models = {
|
||||
"qwen3:0.6b": "https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf",
|
||||
"qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
|
||||
"qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
|
||||
"qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
|
||||
"olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
|
||||
}
|
||||
|
||||
# *** simple OpenAI compatible server on 11434 to match ollama ***
|
||||
|
||||
@@ -219,17 +219,19 @@ DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType)
|
||||
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void", "index":"index"}
|
||||
|
||||
@functools.cache
|
||||
def can_safe_cast(dt0:DType, dt1:DType) -> bool:
|
||||
def can_lossless_cast(dt0:DType, dt1:DType) -> bool:
|
||||
# return if dt1 preserves value of dt0
|
||||
# https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
|
||||
# similar to https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
|
||||
if dt0 == dt1 or dt0 == dtypes.bool: return True
|
||||
match dt1:
|
||||
case dtypes.index: return dt0 in dtypes.ints
|
||||
case dtypes.double: return dt0 in (dtypes.float, dtypes.half, dtypes.bfloat16, *dtypes.fp8s,
|
||||
dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
|
||||
case dtypes.float: return dt0 in (dtypes.half, dtypes.bfloat16, *dtypes.fp8s, dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
|
||||
case dtypes.half: return dt0 in (*dtypes.fp8s, dtypes.uint8, dtypes.int8)
|
||||
case dtypes.uint64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8)
|
||||
case dtypes.uint32: return dt0 in (dtypes.uint16, dtypes.uint8)
|
||||
case dtypes.uint16: return dt0 in (dtypes.uint8,)
|
||||
case dtypes.int64: return dt0 in (dtypes.uint32, dtypes.uint16, dtypes.uint8, dtypes.int32, dtypes.int16, dtypes.int8)
|
||||
case dtypes.int32: return dt0 in (dtypes.uint16, dtypes.uint8, dtypes.int16, dtypes.int8)
|
||||
case dtypes.int16: return dt0 in (dtypes.uint8, dtypes.int8)
|
||||
@@ -320,11 +322,8 @@ def fp8_to_float(x: int, dtype: DType) -> float:
|
||||
truncate: dict[DType, Callable] = {dtypes.bool: bool,
|
||||
dtypes.float16: float_to_fp16, dtypes.bfloat16: lambda x: float_to_bf16(float(x)),
|
||||
**{fp8: (lambda x, dtype=fp8: fp8_to_float(float_to_fp8(x, dtype), dtype)) for fp8 in dtypes.fp8s},
|
||||
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
||||
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
||||
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
||||
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
|
||||
dtypes.int64: lambda x: ctypes.c_int64(x).value}
|
||||
**{getattr(dtypes, n): (lambda x, c=getattr(ctypes, f'c_{n}'): c(x).value)
|
||||
for n in ('float', 'double', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64')}}
|
||||
|
||||
# numpy and torch dtype interop
|
||||
|
||||
|
||||
@@ -20,8 +20,9 @@ from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_so
|
||||
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta, PCIDevice, USBPCIDevice, MAP_FIXED, MAP_NORESERVE
|
||||
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
SQTT, SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT", VIZ.value>=2), ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
|
||||
PMC = ContextVar("PMC", VIZ.value>=2)
|
||||
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
|
||||
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
|
||||
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
|
||||
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
|
||||
WAIT_REG_MEM_FUNCTION_NEQ = 4 # !=
|
||||
@@ -829,8 +830,8 @@ class PCIIface(PCIIfaceBase):
|
||||
doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0), pipe=0, queue=0)
|
||||
else:
|
||||
pv = self.dev_impl.gfx.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr, wptr_addr=gart.va_addr+wptr,
|
||||
eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_MEC_RING0), pipe=0, queue=0,
|
||||
aql=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL))
|
||||
eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_MEC_RING0), pipe=0,
|
||||
queue=int(is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL)), aql=is_aql)
|
||||
|
||||
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],
|
||||
read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv)
|
||||
|
||||
@@ -142,7 +142,7 @@ class AMMemoryManager(MemoryManager):
|
||||
self.dev.gmc.flush_tlb(ip='MM', vmid=0)
|
||||
|
||||
class AMDev(PCIDevImplBase):
|
||||
Version = 0xA0000006
|
||||
Version = 0xA0000007
|
||||
|
||||
def __init__(self, pci_dev:PCIDevice, dma_regions:list[tuple[int, MMIOInterface]]|None=None, reset_mode=False):
|
||||
self.pci_dev, self.devfmt, self.dma_regions = pci_dev, pci_dev.pcibus, dma_regions
|
||||
|
||||
@@ -212,14 +212,19 @@ class AM_SMU(AM_IP):
|
||||
return (self.adev.mmMP1_SMN_C2PMSG_82 if not debug else self.adev.mmMP1_SMN_C2PMSG_53).read() if read_back_arg else None
|
||||
|
||||
class AM_GFX(AM_IP):
|
||||
def init_sw(self): self.xccs = len(self.adev.regs_offset[am.GC_HWIP])
|
||||
def init_sw(self):
|
||||
self.xccs = len(self.adev.regs_offset[am.GC_HWIP])
|
||||
self.mqd_paddr = [self.adev.mm.palloc(0x1000 * self.xccs, zero=False, boot=True) for i in range(2)]
|
||||
self.mqd_mc = [self.adev.paddr2mc(mqd_paddr) for mqd_paddr in self.mqd_paddr]
|
||||
|
||||
def init_hw(self):
|
||||
# Wait for RLC autoload to complete
|
||||
while self.adev.regCP_STAT.read() != 0 and self.adev.regRLC_RLCS_BOOTLOAD_STATUS.read_bitfields()['bootload_complete'] != 0: pass
|
||||
|
||||
self._config_gfx_rs64()
|
||||
self.adev.gmc.init_hub("GC", inst_cnt=self.xccs)
|
||||
if self.adev.partial_boot: return
|
||||
|
||||
self._config_gfx_rs64()
|
||||
|
||||
# NOTE: Golden reg for gfx11. No values for this reg provided. The kernel just ors 0x20000000 to this reg.
|
||||
for xcc in range(self.xccs): self.adev.regTCP_CNTL.write(self.adev.regTCP_CNTL.read() | 0x20000000, inst=xcc)
|
||||
@@ -265,49 +270,61 @@ class AM_GFX(AM_IP):
|
||||
if self.xccs > 1 and not self.adev.partial_boot: self.adev.psp._spatial_partition_cmd(1)
|
||||
|
||||
def fini_hw(self):
|
||||
for xcc in range(self.xccs):
|
||||
self._grbm_select(me=1, pipe=0, queue=0, inst=xcc)
|
||||
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1: self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
|
||||
self._grbm_select(inst=xcc)
|
||||
# NOTE: For aqls with xccs (queue=1), will continue from the saved state.
|
||||
for q in range(2 if self.xccs == 1 else 1):
|
||||
for xcc in range(self.xccs):
|
||||
self._grbm_select(me=1, pipe=0, queue=q, inst=xcc)
|
||||
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1: self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
|
||||
self._grbm_select(inst=xcc)
|
||||
for xcc in range(self.xccs): self.adev.regGCVM_CONTEXT0_CNTL.write(0, inst=xcc)
|
||||
|
||||
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int,
|
||||
aql:bool) -> int:
|
||||
for xcc in range(self.xccs if aql else 1):
|
||||
mqd = self.adev.mm.valloc(0x1000, uncached=True, contiguous=True)
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=0)
|
||||
restore_queue = aql and self.xccs > 1 and self.adev.partial_boot and (self.adev.regCP_HQD_ACTIVE.read(inst=0) & 1)
|
||||
restore_ptr = (self.adev.regCP_HQD_PQ_WPTR_LO.read(inst=0) | (self.adev.regCP_HQD_PQ_WPTR_HI.read(inst=0) << 32)) if restore_queue else 0
|
||||
if DEBUG >= 2 and restore_queue: print(f"am {self.adev.devfmt}: GFX queue already active, continuing from saved state {restore_ptr=:#x}.")
|
||||
|
||||
for xcc in range(self.xccs if aql else 1):
|
||||
struct_t = getattr(am, f"struct_v{self.adev.ip_ver[am.GC_HWIP][0]}{'_compute' if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else ''}_mqd")
|
||||
mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(mqd.va_addr), cp_mqd_base_addr_hi=hi32(mqd.va_addr),
|
||||
mqd_struct = struct_t(header=0xC0310800, cp_mqd_base_addr_lo=lo32(self.mqd_mc[queue] + 0x1000*xcc),
|
||||
cp_mqd_base_addr_hi=hi32(self.mqd_mc[queue] + 0x1000*xcc), cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
|
||||
cp_hqd_persistent_state=self.adev.regCP_HQD_PERSISTENT_STATE.encode(preload_size=0x55, preload_req=1),
|
||||
cp_hqd_pipe_priority=0x2, cp_hqd_queue_priority=0xf, cp_hqd_quantum=0x111,
|
||||
cp_hqd_pq_base_lo=lo32(ring_addr>>8), cp_hqd_pq_base_hi=hi32(ring_addr>>8),
|
||||
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
|
||||
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
|
||||
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.encode(doorbell_offset=doorbell*2, doorbell_en=1),
|
||||
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.encode(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2,
|
||||
**({'queue_full_en':1, 'slot_based_wptr':2, 'wpp_clamp_en':1, 'no_update_rptr':xcc!=0 or self.xccs==1} if aql else {})),
|
||||
**({'queue_full_en':1, 'slot_based_wptr':2, 'no_update_rptr':xcc!=0 or self.xccs==1} if aql else {})),
|
||||
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.encode(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
|
||||
cp_mqd_control=self.adev.regCP_MQD_CONTROL.encode(priv_state=1), cp_hqd_vmid=0, cp_hqd_aql_control=int(aql),
|
||||
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
|
||||
cp_hqd_eop_control=self.adev.regCP_HQD_EOP_CONTROL.encode(eop_size=(eop_size//4).bit_length()-2),
|
||||
**({'compute_tg_chunk_size':1, 'compute_current_logic_xcc_id':xcc} if aql and self.xccs > 1 else {}))
|
||||
**({'compute_tg_chunk_size':1, 'compute_current_logic_xcc_id':xcc, 'cp_mqd_stride_size':0x1000} if aql and self.xccs > 1 else {}))
|
||||
for se in range(8 if self.adev.ip_ver[am.GC_HWIP][0] >= 10 else 4): setattr(mqd_struct, f'compute_static_thread_mgmt_se{se}', 0xffffffff)
|
||||
|
||||
# Copy mqd into memory
|
||||
self.adev.vram.view(mqd.paddrs[0][0], ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
|
||||
self.adev.gmc.flush_hdp()
|
||||
|
||||
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=xcc)
|
||||
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc)
|
||||
if restore_queue:
|
||||
for r in [self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR, self.adev.regCP_HQD_EOP_BASE_ADDR_HI,
|
||||
self.adev.regCP_HQD_PQ_RPTR_REPORT_ADDR_HI, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR, self.adev.regCP_HQD_PQ_WPTR_POLL_ADDR_HI]:
|
||||
val = memoryview(bytes(mqd_struct)).cast('I')[0x80 + (off:=r.addr[xcc] - self.adev.regCP_MQD_BASE_ADDR.addr[xcc])]
|
||||
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct), fmt='I')[0x80 + off] = val
|
||||
r.write(val, inst=xcc)
|
||||
else:
|
||||
self.adev.vram.view(self.mqd_paddr[queue] + 0x1000*xcc, ctypes.sizeof(mqd_struct))[:] = memoryview(mqd_struct).cast('B')
|
||||
|
||||
mqd_st_mv = to_mv(ctypes.addressof(mqd_struct), ctypes.sizeof(mqd_struct)).cast('I')
|
||||
for i, reg in enumerate(range(self.adev.regCP_MQD_BASE_ADDR.addr[xcc], self.adev.regCP_HQD_PQ_WPTR_HI.addr[xcc] + 1)):
|
||||
self.adev.wreg(reg, mqd_st_mv[0x80 + i])
|
||||
self.adev.regCP_HQD_ACTIVE.write(0x1, inst=xcc)
|
||||
|
||||
self.adev.gmc.flush_hdp()
|
||||
self._grbm_select(inst=xcc)
|
||||
|
||||
self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1, inst=xcc)
|
||||
return 0
|
||||
return restore_ptr // 16
|
||||
|
||||
def set_clockgating_state(self):
|
||||
if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
|
||||
@@ -476,7 +493,7 @@ class AM_PSP(AM_IP):
|
||||
|
||||
if not self.is_sos_alive():
|
||||
for fw, compid in sos_components: self._bootloader_load_component(fw, compid)
|
||||
while not self.is_sos_alive(): time.sleep(0.01)
|
||||
wait_cond(self.is_sos_alive, value=True, msg="sOS failed to start")
|
||||
|
||||
self._ring_create()
|
||||
if am.PSP_FW_TYPE_PSP_TOC in self.adev.fw.sos_fw: self._tmr_init()
|
||||
|
||||
@@ -1174,7 +1174,7 @@ if TRACK_MATCH_STATS or PROFILE:
|
||||
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
|
||||
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
|
||||
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
|
||||
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
||||
if VIZ > 0: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
|
||||
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
@@ -68,6 +68,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
# ** zero folding **
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
|
||||
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
|
||||
(UPat.var("x") ^ UPat.var("x"), lambda x: x.const_like(0)), # x^x -> 0
|
||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
|
||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||
# ** constant folding **
|
||||
@@ -97,7 +98,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
(UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
|
||||
# b.cast(a).cast(b) -> b if a preserves all values in b
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_safe_cast(b.dtype, a.dtype) else None),
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
|
||||
# ** pow **
|
||||
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
|
||||
# positive const ** x
|
||||
@@ -242,7 +243,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
(UPat(Ops.RANGE, src=UPat.var("end"), name="r")//UPat.var("end"), lambda r,end: r.const_like(0)),
|
||||
# cast/long folding
|
||||
# if the intermediate cast doesn't narrow we can do it in one cast
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_safe_cast(x.dtype, a.dtype) else None),
|
||||
(UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x.cast(b.dtype) if can_lossless_cast(x.dtype, a.dtype) else None),
|
||||
(UPat.var('x', dtypes.ints+(dtypes.index,)).cast(dtypes.ints+(dtypes.index,), name="a").cast(name="b"),
|
||||
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
|
||||
# try to do math in int instead of long
|
||||
|
||||
@@ -124,7 +124,6 @@
|
||||
fill: rgba(26, 27, 38, 0.5);
|
||||
}
|
||||
.edgePath {
|
||||
stroke: #4a4b57;
|
||||
fill: none;
|
||||
stroke-width: 1.4px;
|
||||
}
|
||||
@@ -322,14 +321,6 @@
|
||||
td.Instruction {
|
||||
font-family: monospace;
|
||||
}
|
||||
td.pct-row > div {
|
||||
height: 12px;
|
||||
width: 100%;
|
||||
display: flex;
|
||||
}
|
||||
td.pct-row > div > div {
|
||||
height: 100%;
|
||||
}
|
||||
thead {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
@@ -350,12 +341,11 @@
|
||||
tr.nested-row table tr.main-row:hover {
|
||||
background-color: unset;
|
||||
}
|
||||
tr.main-row.has-children > td:first-child {
|
||||
white-space: pre;
|
||||
tr.main-row.has-children > td:first-child > p {
|
||||
display: inline-block;
|
||||
}
|
||||
tr.main-row.has-children > td:first-child::before {
|
||||
content: "▸ ";
|
||||
display: inline-block;
|
||||
width: 1em;
|
||||
margin-left: -0.25em;
|
||||
}
|
||||
|
||||
@@ -26,18 +26,17 @@ const colored = n => d3.create("span").call(s => s.selectAll("span").data(typeof
|
||||
const rect = (s) => (typeof s === "string" ? document.querySelector(s) : s).getBoundingClientRect();
|
||||
|
||||
let timeout = null;
|
||||
const updateProgress = ({ start, err }) => {
|
||||
const Status = {STARTED:0, COMPLETE:1, ERR:2}
|
||||
const updateProgress = (st, msg) => {
|
||||
clearTimeout(timeout);
|
||||
const msg = document.getElementById("progress-message");
|
||||
msg.style.display = "none";
|
||||
if (start) {
|
||||
msg.innerText = "Rendering new graph...";
|
||||
timeout = setTimeout(() => { msg.style.display = "block"; }, 2000);
|
||||
}
|
||||
d3.select("#custom").html("");
|
||||
if (err) {
|
||||
const msgEl = d3.select("#progress-message").style("display", "none");
|
||||
const customEl = d3.select("#custom").html("");
|
||||
if (st === Status.STARTED) {
|
||||
msgEl.text(msg);
|
||||
timeout = setTimeout(() => msgEl.style("display", "block"), 2000);
|
||||
} else if (st === Status.ERR) {
|
||||
displaySelection("#custom");
|
||||
d3.select("#custom").append("div").classed("raw-text", true).call(s => s.append(() => codeBlock(err, "txt"))).node();
|
||||
customEl.append("div").classed("raw-text", true).append(() => codeBlock(msg));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,23 +76,19 @@ const drawGraph = (data) => {
|
||||
.attr("x", d => -d.width/2).attr("y", d => -d.height/2);
|
||||
const STROKE_WIDTH = 1.4;
|
||||
const labels = nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label");
|
||||
const hasLabelDims = data.nodes[0]?.value.labelWidth != null;
|
||||
if (hasLabelDims) labels.attr("transform", d => `translate(-${d.labelWidth/2}, -${d.labelHeight/2+STROKE_WIDTH*2})`);
|
||||
labels.attr("transform", d => `translate(-${d.labelWidth/2}, -${d.labelHeight/2+STROKE_WIDTH*2})`);
|
||||
labels.selectAll("text").data(d => {
|
||||
if (Array.isArray(d.label)) return [d.label];
|
||||
const ret = [[]];
|
||||
for (const { st, color } of parseColors(d.label, defaultColor="initial")) {
|
||||
const lines = st.split("\n");
|
||||
for (const s of parseColors(d.label, defaultColor="initial")) {
|
||||
const color = darkenHex(s.color, 25);
|
||||
const lines = s.st.split("\n");
|
||||
ret.at(-1).push({ st:lines[0], color });
|
||||
for (let i=1; i<lines.length; i++) ret.push([{ st:lines[i], color }]);
|
||||
}
|
||||
return [ret];
|
||||
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
|
||||
.attr("fill", d => darkenHex(d.color, 25)).text(d => d.st).attr("xml:space", "preserve");
|
||||
// recenter after drawing texts if needed
|
||||
if (!hasLabelDims) labels.attr("transform", (_,i,els) => {
|
||||
const b = els[i].getBBox();
|
||||
return `translate(${-b.x-b.width/2}, ${-b.y-b.height/2})`
|
||||
});
|
||||
.attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve").style("font-family", g.graph().font);
|
||||
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
|
||||
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
|
||||
// draw edges
|
||||
@@ -104,7 +99,7 @@ const drawGraph = (data) => {
|
||||
points.unshift(intersectRect(g.node(e.v), points[0]));
|
||||
points.push(intersectRect(g.node(e.w), points[points.length-1]));
|
||||
return line(points);
|
||||
}).attr("marker-end", "url(#arrowhead)");
|
||||
}).attr("marker-end", "url(#arrowhead)").attr("stroke", e => g.edge(e).color || "#4a4b57");
|
||||
}
|
||||
|
||||
// ** UOp graph
|
||||
@@ -115,15 +110,15 @@ async function initWorker() {
|
||||
workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" }));
|
||||
}
|
||||
|
||||
function renderDag(graph, additions, recenter, layoutOpts) {
|
||||
function renderDag(layoutSpec, { recenter }) {
|
||||
// start calculating the new layout (non-blocking)
|
||||
updateProgress({ start:true });
|
||||
updateProgress(Status.STARTED, "Rendering new graph...");
|
||||
if (worker != null) worker.terminate();
|
||||
worker = new Worker(workerUrl);
|
||||
worker.postMessage({graph, additions, opts:layoutOpts });
|
||||
worker.postMessage(layoutSpec);
|
||||
worker.onmessage = (e) => {
|
||||
displaySelection("#graph");
|
||||
updateProgress({ start:false });
|
||||
updateProgress(Status.COMPLETE);
|
||||
drawGraph(e.data);
|
||||
addTags(d3.select("#edge-labels").selectAll("g").data(e.data.edges).join("g").attr("transform", (e) => {
|
||||
// get a point near the end
|
||||
@@ -144,7 +139,7 @@ function renderDag(graph, additions, recenter, layoutOpts) {
|
||||
};
|
||||
worker.onerror = (e) => {
|
||||
e.preventDefault();
|
||||
updateProgress({ err:"Error in graph layout:\n"+e.message });
|
||||
updateProgress(Status.ERR, "Error in graph layout:\n"+e.message);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,8 +154,7 @@ const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
|
||||
|
||||
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
|
||||
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
|
||||
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SE:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
|
||||
CATEGORICAL:["#ff8080", "#F4A261", "#C8F9D4", "#8D99AE", "#F4A261", "#ffffa2", "#ffffc0", "#87CEEB"],}
|
||||
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SE:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),}
|
||||
const cycleColors = (lst, i) => lst[i%lst.length];
|
||||
|
||||
const rescaleTrack = (source, tid, k) => {
|
||||
@@ -257,7 +251,7 @@ async function renderProfiler(path, unit, opts) {
|
||||
if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; }
|
||||
metadata.replaceChildren(getMetadata(focusedShape));
|
||||
// layout once!
|
||||
if (data.tracks.size !== 0) return updateProgress({ start:false });
|
||||
if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE);
|
||||
const profiler = d3.select("#profiler").html("");
|
||||
const buf = cache[path] ?? await fetchValue(path);
|
||||
const view = new DataView(buf);
|
||||
@@ -419,7 +413,7 @@ async function renderProfiler(path, unit, opts) {
|
||||
}
|
||||
}
|
||||
for (const m of markers) m.label = m.name.split(/(\s+)/).map(st => ({ st, color:m.color, width:ctx.measureText(st).width }));
|
||||
updateProgress({ start:false });
|
||||
updateProgress(Status.COMPLETE);
|
||||
// draw events on a timeline
|
||||
const dpr = window.devicePixelRatio || 1;
|
||||
const ellipsisWidth = ctx.measureText("...").width;
|
||||
@@ -604,7 +598,7 @@ const pathLink = (fp, lineno) => d3.create("a").attr("href", "vscode://file/"+fp
|
||||
function codeBlock(st, language, { loc, wrap }={}) {
|
||||
const code = document.createElement("code");
|
||||
// plaintext renders like a terminal print, otherwise render with syntax highlighting
|
||||
if (language === "txt") code.appendChild(colored(st));
|
||||
if (!language || language === "txt") code.appendChild(colored(st));
|
||||
else code.innerHTML = hljs.highlight(st, { language }).value;
|
||||
code.className = "hljs";
|
||||
const ret = document.createElement("pre");
|
||||
@@ -636,9 +630,6 @@ hljs.registerLanguage("cpp", (hljs) => ({
|
||||
...hljs.getLanguage('cpp'),
|
||||
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
|
||||
}));
|
||||
hljs.registerLanguage("amdgpu", (hljs) => ({
|
||||
contains: [hljs.COMMENT("//", "$"), { begin:/\b(?:s_|v_|global_|buffer_|scratch_|flat_|ds_)[a-z0-9_]*\b/, className:"code" }]
|
||||
}));
|
||||
|
||||
async function fetchValue(path) {
|
||||
const res = await fetch(path);
|
||||
@@ -800,25 +791,19 @@ async function main() {
|
||||
}
|
||||
const td = tr.append("td").classed(ret.cols[i], true);
|
||||
// string format scalar values
|
||||
if (!Array.isArray(value)) { td.text(typeof value === "string" ? value : ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)); continue; }
|
||||
// display arrays in a bar graph
|
||||
td.classed("pct-row", true);
|
||||
const bar = td.append("div");
|
||||
value.forEach(([k, v, width]) => bar.append("div").style("width", width+"%").attr("title", `${ret.cols[i].labels[k]} ${v}`)
|
||||
.style("background", cycleColors(colorScheme.CATEGORICAL, parseInt(k))))
|
||||
td.append(() => typeof value === "string" ? colored(value) : d3.create("p").text(ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)).node());
|
||||
}
|
||||
}
|
||||
return table;
|
||||
}
|
||||
if (ret.cols != null) {
|
||||
renderTable(root, ret);
|
||||
} else root.append(() => codeBlock(ret.src, ret.lang || "txt"));
|
||||
if (ret.cols != null) renderTable(root, ret);
|
||||
else if (ret.data != null) renderDag(ret, { recenter:true });
|
||||
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
|
||||
ret.metadata?.forEach(m => {
|
||||
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value, idx }) => {
|
||||
const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, idx)).style("width", "100%").style("height", "100%");
|
||||
return [label.trim(), div.text(typeof value === "string" ? value : formatUnit(value)).node()];
|
||||
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
|
||||
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
|
||||
})).node());
|
||||
metadata.appendChild(codeBlock(m.src, "txt")).classList.add("full-height")
|
||||
metadata.appendChild(codeBlock(m.src)).classList.add("full-height")
|
||||
});
|
||||
return document.querySelector("#custom").replaceChildren(root.node());
|
||||
}
|
||||
@@ -843,7 +828,7 @@ async function main() {
|
||||
if (ret.length === 0) return;
|
||||
// ** center graph
|
||||
const data = ret[currentRewrite];
|
||||
const render = (opts) => renderDag(data.graph, data.changed_nodes ?? [], currentRewrite === 0, opts);
|
||||
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
||||
render({ showIndexing:toggle.checked });
|
||||
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
|
||||
// ** right sidebar metadata
|
||||
|
||||
@@ -1,27 +1,57 @@
|
||||
const NODE_PADDING = 10;
// Outer box size for a node whose label measures lw x lh: the label plus NODE_PADDING on every side.
// The raw label size is returned alongside as labelWidth / labelHeight.
const rectDims = (lw, lh) => {
  const pad = NODE_PADDING*2;
  return { width:lw+pad, height:lh+pad, labelWidth:lw, labelHeight:lh };
};
|
||||
|
||||
const LINE_HEIGHT = 14;
|
||||
const canvas = new OffscreenCanvas(0, 0);
|
||||
const ctx = canvas.getContext("2d");
|
||||
ctx.font = `350 ${LINE_HEIGHT}px sans-serif`;
|
||||
|
||||
onmessage = (e) => {
|
||||
const { graph, additions, opts } = e.data;
|
||||
const g = new dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
if (additions.length !== 0) g.setNode("addition", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
|
||||
for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
|
||||
const { data, opts } = e.data;
|
||||
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
|
||||
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
|
||||
postMessage(dagre.graphlib.json.write(g));
|
||||
self.close();
|
||||
}
|
||||
|
||||
// Builds the layout for a control-flow graph: every entry of `blocks` becomes a node that lists
// its instructions, every entry of `paths` becomes coloured edges between nodes, then dagre lays it out.
const layoutCfg = (g, { blocks, paths, pc_table, colors }) => {
  g.setGraph({ rankdir:"TD", font:"monospace" });
  ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
  // basic blocks render the assembly in nodes
  Object.entries(blocks).forEach(([lead, members]) => {
    const label = [];
    let width = 0;
    for (const m of members) {
      const text = pc_table[m][0];
      // the node is as wide as its widest instruction line
      width = Math.max(width, ctx.measureText(text).width);
      // first space-separated token is coloured as the instruction, the rest as operands
      const [inst, ...operands] = text.split(" ");
      label.push([{st:inst+" ", color:"#7aa2f7"}, {st:operands.join(" "), color:"#9aa5ce"}]);
    }
    // one text line per member
    const height = label.length*LINE_HEIGHT;
    g.setNode(lead, { ...rectDims(width, height), label, id:lead, color:"#1a1b26" });
  });
  // paths become edges between basic blocks
  Object.entries(paths).forEach(([lead, value]) => {
    Object.entries(value).forEach(([id, color]) => {
      g.setEdge(lead, id, {label:{type:"port", text:""}, color:colors[color]});
    });
  });
  dagre.layout(g);
}
|
||||
|
||||
const layoutUOp = (g, { graph, change }, opts) => {
|
||||
g.setGraph({ rankdir: "LR", font:"sans-serif" });
|
||||
ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
|
||||
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
|
||||
for (const [k, {label, src, ref, color, tag }] of Object.entries(graph)) {
|
||||
// adjust node dims by label size (excluding escape codes) + add padding
|
||||
let [width, height] = [0, 0];
|
||||
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
|
||||
width = Math.max(width, ctx.measureText(line).width);
|
||||
height += LINE_HEIGHT;
|
||||
}
|
||||
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest});
|
||||
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag});
|
||||
// add edges
|
||||
const edgeCounts = {}
|
||||
const edgeCounts = {};
|
||||
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
|
||||
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
|
||||
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
|
||||
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
||||
}
|
||||
// optionally hide nodes from the layout
|
||||
if (!opts.showIndexing) {
|
||||
@@ -31,8 +61,6 @@ onmessage = (e) => {
|
||||
}
|
||||
}
|
||||
dagre.layout(g);
|
||||
// remove additions overlay if it's empty
|
||||
if (!g.node("addition")?.width) g.removeNode("addition");
|
||||
postMessage(dagre.graphlib.json.write(g));
|
||||
self.close();
|
||||
// remove overlay node if it's empty
|
||||
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ from decimal import Decimal
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.helpers import printable, system, TCPServerWithReuse, HTTPRequestHandler
|
||||
from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
|
||||
from tinygrad.uop.ops import print_uops, range_start, multirange_str
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
|
||||
@@ -53,7 +53,7 @@ class GraphRewriteDetails(TypedDict):
|
||||
graph: dict # JSON serialized UOp for this rewrite step
|
||||
uop: str # stringified UOp for this rewrite step
|
||||
diff: list[str]|None # diff of the single UOp that changed
|
||||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
change: list[int]|None # the new UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
|
||||
def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
|
||||
@@ -115,14 +115,14 @@ def _reconstruct(a:int):
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
||||
next_sink = _reconstruct(ctx.sink)
|
||||
# in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None}
|
||||
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
@@ -248,23 +248,25 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
||||
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
|
||||
if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e
|
||||
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
|
||||
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), \
|
||||
(durations, {k:v[ProfilePMCEvent][0] for k,v in counter_events.items()}))]})
|
||||
if len(counter_events) == 0: return None
|
||||
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
|
||||
run_number = {n:0 for n,_ in counter_events}
|
||||
for k,v in counter_events.items():
|
||||
prg = trace.keys[r].ret if (r:=ref_map.get(k[0])) else None
|
||||
name = prg.name if prg is not None else k[0]
|
||||
run_number[k[0]] += 1
|
||||
for (k, tag),v in counter_events.items():
|
||||
# use the colored name if it exists
|
||||
name = trace.keys[r].ret.name if (r:=ref_map.get(k)) is not None else k
|
||||
run_number[k] += 1
|
||||
steps:list[dict] = []
|
||||
if (pmc:=v.get(ProfilePMCEvent)): steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
|
||||
if (pmc:=v.get(ProfilePMCEvent)):
|
||||
steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
|
||||
all_counters[(name, run_number[k], k)] = pmc[0]
|
||||
if (sqtt:=v.get(ProfileSQTTEvent)):
|
||||
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), (k, [*sqtt, prg_events[k[0]], dev_events[sqtt[0].device]])))
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]])))
|
||||
if getenv("SQTT_PARSE"):
|
||||
# run our decoder on startup, we don't use this since it only works on gfx11
|
||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||
for e in sqtt: parse_sqtt_print_packets(e.blob)
|
||||
ctxs.append({"name":f"Exec {name} n{run_number[k[0]]}", "steps":steps})
|
||||
ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps})
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
@@ -337,26 +339,6 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
||||
|
||||
# ** Assembly static analyzers
|
||||
|
||||
def get_llvm_mca(asm:str, mtriple:str, mcpu:str) -> dict:
  """Run `llvm-mca` on `asm` and reshape its JSON report into a table description.

  The returned dict has three keys:
    "rows": one list per instruction: the instruction text, its latency and, for instructions that
            appear in the resource pressure view, a list of [resource index, usage, percent of max usage]
    "cols": the column headers, the last one carrying the hardware resource labels
    "metadata": a single-element list holding the per-resource usage summary
  """
  target_args = f"-mtriple={mtriple} -mcpu={mcpu}"
  # disassembly output can include headers / metadata, skip if llvm-mca can't parse those lines
  cmd = "llvm-mca -skip-unsupported-instructions=parse-failure --json -"+target_args
  report = json.loads(system(cmd, input=asm.encode()))
  region = report["CodeRegions"][0]
  resource_labels = [repr(res)[1:-1] for res in report["TargetInfo"]["Resources"]]
  rows:list = [[inst] for inst in region["Instructions"]]
  # scheduler estimates: the latency becomes the second column of each row
  for entry in region["InstructionInfoView"]["InstructionList"]:
    rows[entry["Instruction"]].append(entry["Latency"])
  # accumulate usage per instruction index, then per resource index
  usage_by_instr:dict[int, dict[int, int]] = {}
  for pressure in region["ResourcePressureView"]["ResourcePressureInfo"]:
    per_res = usage_by_instr.setdefault(pressure["InstructionIndex"], {})
    res_idx = pressure["ResourceIndex"]
    per_res[res_idx] = per_res.get(res_idx, 0) + pressure["ResourceUsage"]
  # the entry indexed one past the last instruction is the usage summary, take it out before scaling
  summary_usage = usage_by_instr.pop(len(rows), {})
  summary = [{"idx":idx, "label":resource_labels[idx], "value":val} for idx,val in summary_usage.items()]
  # usage is displayed relative to the busiest instruction
  max_usage = max((sum(per_res.values()) for idx,per_res in usage_by_instr.items() if idx < len(rows)), default=0)
  for idx,per_res in usage_by_instr.items():
    rows[idx].append([[res, val, (val/max_usage)*100] for res,val in per_res.items()])
  cols = ["Instruction", "Latency", {"title":"HW Resources", "labels":resource_labels}]
  return {"rows":rows, "cols":cols, "metadata":[summary]}
|
||||
|
||||
def get_stdout(f: Callable) -> str:
|
||||
buf = io.StringIO()
|
||||
try:
|
||||
@@ -376,6 +358,67 @@ def amd_readelf(lib:bytes) -> list[dict]:
|
||||
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
|
||||
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
|
||||
|
||||
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
  """Disassemble the `.text` section of an ELF with LLVM's AMDGPU disassembler.

  `arch` is handed to LLVM as the CPU name for the "amdgcn-amd-amdhsa" triple, `lib` is the ELF binary.
  Returns a dict, in decode order, mapping each instruction's offset (starting at the section's
  `sh_addr`) to a `(disassembly text, size in bytes)` tuple.

  Raises RuntimeError if LLVM cannot decode the bytes at some offset.
  """
  from tinygrad.runtime.autogen import llvm
  from tinygrad.runtime.support.elf import elf_loader
  llvm.LLVMInitializeAMDGPUTargetInfo()
  llvm.LLVMInitializeAMDGPUTargetMC()
  llvm.LLVMInitializeAMDGPUAsmParser()
  llvm.LLVMInitializeAMDGPUDisassembler()
  # pass NULL to callbacks (positional arguments 5 and 6), a tuple keeps their order explicit
  cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in (5, 6)]
  # NOTE(review): the disassembler context is never disposed here, confirm whether the bindings expose a dispose call
  ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
  image, sections, _ = elf_loader(lib)
  text = next((sh.header for sh in sections if sh.name == ".text"), None)
  assert text is not None, "no .text section found in ELF"
  off, sz = text.sh_addr, text.sh_size
  end = off + sz
  addr_table:dict[int, tuple[str, int]] = {}
  out = ctypes.create_string_buffer(128)
  cur_off = off
  while cur_off < end:
    # hand LLVM everything from the current offset up to the end of the section
    view = (ctypes.c_ubyte * (end - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
    instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
    # a size of 0 means no valid instruction was decoded, cur_off would never advance and the loop would spin forever
    if instr_sz == 0: raise RuntimeError(f"failed to disassemble instruction at offset {cur_off:#x}")
    addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
    cur_off += instr_sz
  return addr_table
|
||||
|
||||
SOPP_INSTS = {"s_branch", "s_cbranch_scc0", "s_cbranch_scc1", "s_cbranch_vccz", "s_cbranch_vccnz", "s_cbranch_execz", "s_cbranch_execnz"}
|
||||
def parse_branch(asm:str) -> int|None:
|
||||
inst, *operands = asm.split(" ")
|
||||
if inst in SOPP_INSTS:
|
||||
x = int(operands[0]) & 0xffff
|
||||
return (x - 0x10000 if x & 0x8000 else x)*4
|
||||
return None
|
||||
|
||||
# edge kinds of the control flow graph and the color associated with each
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
def amdgpu_cfg(lib:bytes, arch:str) -> dict:
  """Build a control flow graph from the disassembly of `lib` for `arch`.

  Returns a dict with:
    "blocks": leader pc -> list of the pcs in that basic block
    "paths": leader pc -> {successor pc: COND_TAKEN | COND_NOT_TAKEN | UNCOND}
    "pc_table": the disassembly returned by llvm_disasm
    "colors": cfg_colors
  """
  pc_table = llvm_disasm(arch, lib)
  # the first instruction always starts a block
  leaders:set[int] = {next(iter(pc_table))}
  # branch displacement of every instruction, None for non-branches
  branch_off = {pc: parse_branch(asm) for pc, (asm, _) in pc_table.items()}
  # both the branch target and the instruction following a branch start a new block
  for pc, (_, sz) in pc_table.items():
    if (disp:=branch_off[pc]) is None: continue
    fallthrough = pc+sz
    leaders.add(fallthrough+disp)
    leaders.add(fallthrough)
  # walk the instructions in order, grouping them under the most recent leader
  blocks:dict[int, list[int]] = {}
  paths:dict[int, dict[int, int]] = {}
  curr:int|None = None
  for pc, (asm, sz) in pc_table.items():
    if pc in leaders:
      curr = pc
      paths[pc], blocks[pc] = {}, []
    else: assert curr is not None, f"no basic block found for {pc}"
    blocks[curr].append(pc)
    # the walk stops at the first s_endpgm
    # NOTE(review): instructions placed after it are never added to a block, confirm this is intended
    if asm == "s_endpgm": break
    fallthrough, disp, edges = pc+sz, branch_off[pc], paths[curr]
    if disp is None:
      # plain instruction: only link to the next one when that one starts a new block
      if fallthrough in leaders: edges[fallthrough] = UNCOND
    elif asm.startswith("s_branch"): edges[fallthrough+disp] = UNCOND
    else:
      # conditional branch: one edge to the target, one to the following instruction
      edges[fallthrough+disp] = COND_TAKEN
      edges[fallthrough] = COND_NOT_TAKEN
  return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
|
||||
|
||||
# ** Main render function to get the complete details about a trace event
|
||||
|
||||
def get_render(i:int, j:int, fmt:str) -> dict:
|
||||
@@ -386,22 +429,19 @@ def get_render(i:int, j:int, fmt:str) -> dict:
|
||||
if fmt == "asm":
|
||||
compiler = Device[data.device].compiler
|
||||
disasm_str = get_stdout(lambda: compiler.disassemble(compiler.compile(data.src)))
|
||||
from tinygrad.runtime.support.compiler_cpu import llvm, LLVMCompiler
|
||||
if isinstance(compiler, LLVMCompiler):
|
||||
return get_llvm_mca(disasm_str, ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode(),
|
||||
ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode())
|
||||
metadata:list = []
|
||||
ret:dict = {"src":disasm_str}
|
||||
if data.device.startswith("AMD"):
|
||||
with soft_err(lambda err: metadata.append(err)):
|
||||
metadata.append(amd_readelf(compiler.compile(data.src)))
|
||||
return {"src":disasm_str, "lang":"amdgpu" if data.device.startswith("AMD") else None, "metadata":metadata}
|
||||
with soft_err(lambda err: ret.update(err)):
|
||||
metadata = amd_readelf(lib:=compiler.compile(data.src))
|
||||
ret = {"data":amdgpu_cfg(lib, getattr(compiler, "arch")), "metadata":[metadata]}
|
||||
return ret
|
||||
if fmt == "all-pmc":
|
||||
durations, pmc = data
|
||||
ret:dict = {"cols":{}, "rows":[]}
|
||||
for (prg,_),events in pmc.items():
|
||||
ret = {"cols":{}, "rows":[]}
|
||||
for (name, n, k),events in data[1].items():
|
||||
pmc_table = unpack_pmc(events)
|
||||
ret["cols"].update([(r[0], None) for r in pmc_table["rows"]])
|
||||
ret["rows"].append((prg, durations[prg].pop(0), *[r[1] for r in pmc_table["rows"]]))
|
||||
ret["rows"].append((name, durations[k][n-1], *[r[1] for r in pmc_table["rows"]]))
|
||||
ret["cols"] = ["Kernel", "Duration", *ret["cols"]]
|
||||
return ret
|
||||
if fmt == "prg-pmc": return unpack_pmc(data[0])
|
||||
|
||||
Reference in New Issue
Block a user