Merge branch 'master' into test_fa

This commit is contained in:
George Hotz
2025-10-17 22:41:31 +08:00
committed by GitHub
2 changed files with 32 additions and 17 deletions

View File

@@ -1242,7 +1242,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
G, V, H = G.detach(), V.detach(), H.detach()
X.grad = norm_coefficient * X.detach() + G
opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
opt.m, opt.v, opt.lr = [V], [H], R
# NOTE: FUSE_OPTIM can change shapes of m and v
opt.m, opt.v, opt.lr = [V.reshape(opt.m[0].shape)], [H.reshape(opt.v[0].shape)], R
# need no-op for m_hat and v_hat if T == 0
if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
else:

View File

@@ -218,11 +218,10 @@ async function renderProfiler() {
if (eventType === EventTypes.TIMELINE) {
const levelHeight = baseHeight-padding;
const levels = [];
data.tracks.set(k, { shapes, visible, offsetY });
data.tracks.set(k, { shapes, visible, offsetY, pcolor:"#9ea2ad" });
let colorKey, ref;
for (let j=0; j<eventsLen; j++) {
const e = {name:strings[u32()], ref:optional(u32()), key:optional(u32()), st:u32(), dur:f32(), info:strings[u32()] || null};
if (e.key != null) shapeMap.set(e.key, e);
// find a free level to put the event
let depth = levels.findIndex(levelEt => e.st >= levelEt);
const et = e.st+Math.trunc(e.dur);
@@ -242,7 +241,18 @@ async function renderProfiler() {
const stepIdx = ctxs[ref.ctx+1].steps.findIndex((s, i) => i >= start && s.name == e.name);
if (stepIdx !== -1) { ref.step = stepIdx; shapeRef = ref; }
}
const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), ...shapeRef };
const html = document.createElement("div");
html.appendChild(tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node());
if (e.info != null) html.appendChild(document.createElement("p")).innerText = "\n"+e.info;
if (shapeRef != null) {
const p = html.appendChild(document.createElement("p"));
p.innerText = "\nView Codegen Rewrite"; p.style.cursor = "pointer";
p.onclick = () => setCtxWithHistory(shapeRef.ctx, shapeRef.step);
}
// tiny device events go straight to the rewrite rule
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
const arg = { tooltipText:colored(e.name).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), html, key, ...shapeRef };
if (e.key != null) shapeMap.set(e.key, arg);
// offset y by depth
shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label, fillColor });
}
@@ -262,7 +272,7 @@ async function renderProfiler() {
x += 1; y += nbytes; valueMap.set(ts, y);
} else {
const free = buf_shapes.get(key);
free.users = Array.from({ length: u32() }, () => ({...shapeMap.get(u32()), repr:strings[u32()], num:u8(), mode:u8()}));
free.users = Array.from({ length: u32() }, () => ({shape:shapeMap.get(u32()), repr:strings[u32()], num:u8(), mode:u8()}));
timestamps.push(ts); valueMap.set(ts, y);
x += 1; y -= free.nbytes;
free.x.push(x);
@@ -286,12 +296,13 @@ async function renderProfiler() {
if (users != null) rows.push(["Users", users.length]);
const info = html.appendChild(tabulate(rows).node());
for (let u=0; u<users?.length; u++) {
const p = html.appendChild(document.createElement("p")); p.style.marginTop = "4px"; p.style.cursor = "pointer";
const { repr, num, mode, info, ref } = users[u]; p.appendChild(colored(`[${u}] ${repr} ${mode == 2 ? 'read+write' : mode == 1 ? 'write' : 'read'}@data${num}`));
const metadata = info?.split("\n")[1]
const p = html.appendChild(document.createElement("p")); p.style.marginTop = "4px";
const { repr, num, mode, shape } = users[u]; p.appendChild(colored(`[${u}] ${repr} ${mode == 2 ? 'read+write' : mode == 1 ? 'write' : 'read'}@data${num}`));
const metadata = shape?.tooltipText?.split("\n").at(-1);
if (metadata != null) p.appendChild(document.createElement("span")).innerText = "\n"+metadata;
p.onclick = () => {
if (ref != null) setCtxWithHistory(ref);
if (shape != null) {
p.style.cursor = "pointer";
p.onclick = () => focusShape(shape);
}
}
const arg = {tooltipText:info.outerHTML, html, key:`${k}-${num}`};
@@ -317,7 +328,7 @@ async function renderProfiler() {
sum.x.push(allX[i], allX[i+1]);
const y = maxY.get(allX[i]); sum.y1.push(y, y); sum.y0.push(base0, base0);
}
data.tracks.set(k, { shapes:[sum], visible, offsetY, height, peak, scaleFactor:maxheight*4/height, views:[[sum], shapes], valueMap });
data.tracks.set(k, { shapes:[sum], visible, offsetY, pcolor:"#c9a8ff", height, peak, scaleFactor:maxheight*4/height, views:[[sum], shapes], valueMap });
div.style("height", height+padding+"px").style("cursor", "pointer").on("click", (e) => {
const newFocus = e.currentTarget.id === focusedDevice ? null : e.currentTarget.id;
let offset = 0;
@@ -346,7 +357,7 @@ async function renderProfiler() {
xscale.domain(visibleX);
// draw shapes
const paths = [];
for (const [_, { offsetY, shapes, visible, valueMap }] of data.tracks) {
for (const [_, { offsetY, shapes, visible, valueMap, pcolor }] of data.tracks) {
visible.length = 0;
for (const e of shapes) {
const p = new Path2D();
@@ -385,7 +396,7 @@ async function renderProfiler() {
lw += e.label[li].width;
}
}
if (focusedShape?.key && e.arg?.key === focusedShape.key) { paths.push(p); }
if (focusedShape?.key && e.arg?.key === focusedShape.key) { paths.push([p, pcolor]); }
}
}
// draw axes
@@ -415,7 +426,7 @@ async function renderProfiler() {
drawLine(ctx, [x, x], [0, canvas.clientHeight], { color:m.color });
ctx.fillText(m.name, x+2, 1);
}
for (const p of paths) { ctx.lineWidth = 1.4; ctx.strokeStyle = "#c9a8ff"; ctx.stroke(p); }
for (const [p, color] of paths) { ctx.lineWidth = 1.4; ctx.strokeStyle = color; ctx.stroke(p); }
}
function resize() {
@@ -452,12 +463,15 @@ async function renderProfiler() {
}
}
function focusShape(shape) {
focusedShape = shape; render(zoomLevel);
return document.querySelector(".metadata").replaceChildren(shape?.html ?? "");
}
canvas.addEventListener("click", e => {
e.preventDefault();
const foundRect = findRectAtPosition(e.clientX, e.clientY);
if (foundRect?.step != null) return setCtxWithHistory(foundRect.ctx, foundRect.step);
if (foundRect?.key != focusedShape?.key) { focusedShape = foundRect; render(zoomLevel); }
return document.querySelector(".metadata").replaceChildren(foundRect?.html ?? "");
if (foundRect?.step != null && foundRect?.key == null) { return setCtxWithHistory(foundRect.ctx, foundRect.step); }
if (foundRect?.key != focusedShape?.key) { focusShape(foundRect); }
});
canvas.addEventListener("mousemove", e => {