don't place consts early + fixes from anthropic challenge (#14286)

* don't place consts early

* add anthropic challenge

* with ref

* do we still have to devectorize bools?

* tests pass

* just WHERE

* fine, revert that

* fine, revert

* only index

* z3 validator doesn't support vectorized

* Revert "z3 validator doesn't support vectorized"

This reverts commit 1b7930ecb3.

* z3 not for vec

* no spec

* VLIWRenderer

* loop unrolling

* better comments

* cleanups

* skip cast

* renderer

* cleanups

* prints

* no hack

* hacks

* bump to 11

* reg warning

* lil clean

* cleaner renderer
This commit is contained in:
George Hotz
2026-01-23 09:48:39 +08:00
committed by GitHub
parent 0903782bc0
commit 52b989c6c8
8 changed files with 216 additions and 12 deletions

View File

@@ -561,7 +561,7 @@ jobs:
- name: openpilot compile3 0.10.1 driving_policy
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.10.1 dmonitoring
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=10 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
- name: benchmark MobileNetV2 on DSP
run: |
# generate quantized weights

View File

@@ -0,0 +1,196 @@
from tinygrad import Tensor, dtypes, Context, getenv, UOp, fetch
from tinygrad.uop.ops import Ops, PatternMatcher, UPat
from tinygrad.uop.symbolic import symbolic
from tinygrad.codegen import Renderer
from tinygrad.codegen.opt import Opt, OptOps
# ************************* implementation of the problem ************************
def myhash(a: Tensor) -> Tensor:
a = (a + 0x7ED55D16) + (a << 12)
a = (a ^ 0xC761C23C) ^ (a >> 19)
a = (a + 0x165667B1) + (a << 5)
a = (a + 0xD3A2646C) ^ (a << 9)
a = (a + 0xFD7046C5) + (a << 3)
a = (a ^ 0xB55A4F09) ^ (a >> 16)
return a
def select_with_where_tree(values: Tensor, relative_idx: Tensor) -> Tensor:
n = values.shape[0]
if n == 1: return values[0].expand(relative_idx.shape)
mid = n // 2
left = select_with_where_tree(values[:mid], relative_idx)
right = select_with_where_tree(values[mid:], relative_idx - mid)
go_left = relative_idx < mid
return go_left.where(left, right)
def tree_traversal(forest: Tensor, val: Tensor, height: int, rounds: int, where_tree_threshold=3) -> Tensor:
# All walkers start at idx=0
idx = Tensor.zeros(val.shape, device=val.device, dtype=dtypes.uint32)
for r in range(rounds):
level = r % (height + 1)
level_start = (1 << level) - 1
level_size = 1 << level
if level == 0:
# At root (level 0), all walkers are at idx=0
# No gather needed, just broadcast the root value
node_val = forest[0].expand(val.shape)
idx = idx * 0 # Reset to 0
elif level <= where_tree_threshold:
# Small level: use where-tree
level_values = forest[level_start : level_start + level_size]
relative_idx = (idx - level_start)
node_val = select_with_where_tree(level_values, relative_idx)
else:
# Large level: use gather
node_val = forest.gather(0, idx)
val = myhash(val ^ node_val)
idx = (idx << 1) + (1 + (val & 1))
# No wrap check needed! At round 10 (level becomes 0), we reset idx above.
return val.contiguous(arg=(Opt(OptOps.UPCAST, 0, 8),))
# ************************* renderer for VLIW machine *************************
def loop_unrolling(sink:UOp):
rng = [x for x in sink.toposort() if x.op is Ops.RANGE]
if len(rng) == 0: return None
print(f"unrolling loop with size {rng[0].vmax+1}")
unrolled_sinks = [sink.substitute({rng[0]:rng[0].const_like(i)}).src[0] for i in range(rng[0].vmax+1)]
return UOp.sink(*unrolled_sinks, arg=sink.arg)
global_addrs = []
vliw_prepare = PatternMatcher([
# loop unrolling (should be a part of tinygrad)
(UPat(Ops.SINK, name="sink"), loop_unrolling),
# cast is fake
(UPat(Ops.CAST, name="c"), lambda c: c.src[0]),
# rewrites to hardcode the addresses in memory
(UPat(Ops.DEFINE_GLOBAL, name="dg"), lambda dg: UOp.const(dtypes.uint, global_addrs[dg.arg])),
# INDEX is just plus
(UPat(Ops.INDEX, name="i"), lambda i: i.src[0]+i.src[1]),
])+symbolic
class VLIWRenderer(Renderer):
has_local = False # TODO: this should be the default / cleaned up
# this says this backend supports MULACC + more. decompositions uses this
code_for_op: dict = {Ops.MULACC: None, Ops.ADD: "+", Ops.MUL: "*",
Ops.XOR: "^", Ops.AND: "&", Ops.OR: "|",
Ops.SHL: "<<", Ops.SHR: ">>", Ops.CMPLT: "<"}
# this matcher runs while still in graph form
pre_matcher = vliw_prepare
def render(self, uops:list[UOp]):
# TODO: this is a minimal renderer. for low cycle count, make it good
# to get speed, you need to add VLIW packing
# to get under 1536 regs, you need to add a register allocator
# we left the fun parts to you
print(f"rendering with {len(uops)} uops")
reg, inst = 0, []
r: dict[UOp, int] = {}
for u in uops:
assert u.dtype.count in (1,8), "dtype count must be 1 or 8"
# dumb register allocator
if u.op not in {Ops.STORE, Ops.SINK, Ops.GEP}:
r[u] = reg
reg += u.dtype.count
# render UOps to instructions
match u.op:
case Ops.SINK:
inst.append({"flow": [("halt",)]})
case Ops.CONST:
inst.append({"load": [("const", r[u], u.arg)]})
case Ops.GEP:
# a GEP is just an alias to a special register in the vector
r[u] = r[u.src[0]] + u.arg[0]
case Ops.VECTORIZE:
if all(s == u.src[0] for s in u.src):
# if all sources are the same, we can broadcast
inst.append({"valu": [("vbroadcast", r[u], r[u.src[0]])]})
else:
# this is a copy into a contiguous chunk of registers
inst.extend({"flow": [("add_imm", r[u]+i, r[s], 0)]} for i,s in enumerate(u.src) if r[s] != r[u]+i)
case Ops.LOAD:
op = "vload" if u.dtype.count > 1 else "load"
inst.append({"load": [(op, r[u], r[u.src[0]])]})
case Ops.STORE:
op = "vstore" if u.src[1].dtype.count > 1 else "store"
inst.append({"store": [(op, r[u.src[0]], r[u.src[1]])]})
case Ops.MULACC:
assert u.dtype.count == 8
inst.append({"valu": [("multiply_add", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]})
case Ops.WHERE:
assert u.dtype.count == 8
inst.append({"flow": [("vselect", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]})
case _ if u.op in self.code_for_op:
cat = "valu" if u.dtype.count > 1 else "alu"
inst.append({cat: [(self.code_for_op[u.op], r[u], r[u.src[0]], r[u.src[1]])]})
case _:
raise NotImplementedError(f"unhandled op {u.op}")
return repr(inst)
# ************************* test and render *************************
import sys, types
PROBLEM_URL = "https://raw.githubusercontent.com/anthropics/original_performance_takehome/refs/heads/main/tests/frozen_problem.py"
sys.modules["problem"] = problem = types.ModuleType("problem")
exec(fetch(PROBLEM_URL).read_text(), problem.__dict__)
if __name__ == "__main__":
batch_size = getenv("BS", 256)
height = 10
rounds = getenv("ROUNDS", 16)
# build problem
tree = problem.Tree.generate(height)
inp = problem.Input.generate(tree, batch_size, rounds)
mem = problem.build_mem_image(tree, inp)
global_addrs.extend([mem[6], mem[6], mem[4]]) # output, input, forest
# *** verify the kernel in tinygrad compared to reference ***
forest_t = Tensor(tree.values, dtype=dtypes.uint32)
val_t = Tensor(inp.values, dtype=dtypes.uint32)
if getenv("VERIFY", 1):
# verify on normal tinygrad device
with Context(PCONTIG=2):
out = tree_traversal(forest_t, val_t, height, rounds)
val_out = out.tolist()
problem.reference_kernel(tree, inp)
assert val_out == inp.values
print("verification passed")
# *** render to device ***
from tinygrad.codegen import get_program
with Context(PCONTIG=2, DEVECTORIZE=2, SPEC=0):
out = tree_traversal(forest_t, val_t, height, rounds)
sink = out.schedule()[-1].ast
prg = get_program(sink, VLIWRenderer())
# *** run on Machine and compare ***
# NOTE: the scratch size needs to be reduced to 1536 when you have a register allocator
src = eval(prg.src)
max_regs = max(t[1] for instr in src for v in instr.values() for t in v if len(t) > 1) + 8
print(f"{max_regs:5d} regs used" + ("" if max_regs <= 1536 else " <-- WARNING: TOO MANY REGISTERS, MUST BE <= 1536"))
machine = problem.Machine(mem, src, problem.DebugInfo(scratch_map={}), n_cores=1, trace=False, scratch_size=max_regs)
machine.run()
print(f"ran for {machine.cycle:5d} cycles" + ("" if machine.cycle <= 1363 else " <-- EVEN CLAUDE GOT 1363"))
# compare to reference
ref_mem = mem.copy()
for _ in problem.reference_kernel2(ref_mem, {}): pass
assert machine.mem[mem[6]:mem[6]+mem[2]] == ref_mem[mem[6]:mem[6]+mem[2]]
print("compare passed!")

View File

@@ -84,10 +84,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing, ctx=ren.device, name="lower all index dtypes")
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes")
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# optional pre matcher

View File

@@ -243,11 +243,17 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
cnt = cast.dtype.count
vcnt = cast.dtype.vcount
precnt = bcast.dtype.vcount
input_gep = bcast.arg if bcast.op is Ops.GEP else ([0]*precnt)
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
sum_arg = tuple(flatten([[i+y for y in input_gep] for i in range(cnt)]))
return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg), ptr=True)
# TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school.
if bcast.op is Ops.GEP:
gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)]))
sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)]))
else:
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),

View File

@@ -30,7 +30,6 @@ def linearize(sink:UOp) -> list[UOp]:
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
case Ops.DEFINE_LOCAL: priority = -18
case Ops.DEFINE_REG: priority = -17
case Ops.CONST: priority = -10 # early consts
case Ops.LOAD: priority = -1 # place loads early
case Ops.STORE: priority = 1 # place stores late
case Ops.RANGE: priority = 5 # placing RANGE is good

View File

@@ -15,8 +15,9 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
# TODO: validate these
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
for x in idx.toposort() | gate.toposort():
if x.op is Ops.BITCAST or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
if x.op in {Ops.BITCAST, Ops.VECTORIZE, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
# if all is good and IGNORE_OOB=0, validate with z3
from tinygrad.uop.validate import validate_index_with_z3

View File

@@ -163,8 +163,8 @@ gep_pushing = PatternMatcher([
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
# GEP in order is removed
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
# push all GEPs through ALUs (fix arange stuff)
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
# push all GEPs through ALUs for index (TODO: remove this)
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.index, name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)

View File

@@ -48,7 +48,9 @@ def uops_to_z3(solver, *uops: UOp) -> list[z3.ExprRef]:
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1]
z3map: dict[UOp, z3.ExprRef] = {}
for i,u in enumerate(lst):
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_renderer.rewrite(u, ctx=(solver, z3map)))
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver, z3map))
if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3")
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten)
if constraint is not None: solver.add(constraint)
z3map[u] = new_u
assert all(u in z3map for u in uops), "UOp failed to rewrite to z3!"